Example #1
def test_query_history(rotkehlchen_api_server_with_exchanges):
    """Test that the history processing REST API endpoint works. Similar to test_history.py"""
    rotki = rotkehlchen_api_server_with_exchanges.rest_api.rotkehlchen
    setup = prepare_rotki_for_history_processing_test(
        rotki,
        should_mock_history_processing=False,
    )

    # Query history processing to start the history processing
    with ExitStack() as stack:
        for manager in setup:
            if manager is None:
                continue
            stack.enter_context(manager)
        response = requests.get(
            api_url_for(rotkehlchen_api_server_with_exchanges, "historyprocessingresource"),
        )

    # Simply check that the results got returned here. The actual correctness of
    # accounting results is checked in other tests such as test_simple_accounting
    assert_proper_response(response)
    data = response.json()
    assert data['message'] == ''
    assert len(data['result']) == 2
    overview = data['result']['overview']
    assert len(overview) == 10
    assert overview["loan_profit"] is not None
    assert overview["margin_positions_profit_loss"] is not None
    assert overview["settlement_losses"] is not None
    assert overview["ethereum_transaction_gas_costs"] is not None
    assert overview["asset_movement_fees"] is not None
    assert overview["general_trade_profit_loss"] is not None
    assert overview["taxable_trade_profit_loss"] is not None
    assert overview["total_taxable_profit_loss"] is not None
    assert overview["total_profit_loss"] is not None
    assert overview["defi_profit_loss"] is not None
    all_events = data['result']['all_events']
    assert isinstance(all_events, list)
    # TODO: These events are not actually checked anywhere for correctness
    #       A test should probably be made for their correctness, even though
    #       they are assumed correct if the overview is correct
    assert len(all_events) == 36

    # And now make sure that warnings have also been generated for the query of
    # the unsupported/unknown assets
    warnings = rotki.msg_aggregator.consume_warnings()
    assert len(warnings) == 13
    assert 'poloniex trade with unknown asset NOEXISTINGASSET' in warnings[0]
    assert 'poloniex trade with unsupported asset BALLS' in warnings[1]
    assert 'withdrawal of unknown poloniex asset IDONTEXIST' in warnings[2]
    assert 'withdrawal of unsupported poloniex asset DIS' in warnings[3]
    assert 'deposit of unknown poloniex asset IDONTEXIST' in warnings[4]
    assert 'deposit of unsupported poloniex asset EBT' in warnings[5]
    assert 'poloniex loan with unsupported asset BDC' in warnings[6]
    assert 'poloniex loan with unknown asset NOTEXISTINGASSET' in warnings[7]
    assert 'bittrex trade with unsupported asset PTON' in warnings[8]
    assert 'bittrex trade with unknown asset IDONTEXIST' in warnings[9]
    assert 'kraken trade with unknown asset IDONTEXISTTOO' in warnings[10]
    assert 'unknown kraken asset IDONTEXIST. Ignoring its deposit/withdrawals' in warnings[11]
    msg = 'unknown kraken asset IDONTEXISTEITHER. Ignoring its deposit/withdrawals query'
    assert msg in warnings[12]

    errors = rotki.msg_aggregator.consume_errors()
    assert len(errors) == 3
    assert 'bittrex trade with unprocessable pair %$#%$#%#$%' in errors[0]
    assert 'kraken trade with unprocessable pair IDONTEXISTZEUR' in errors[1]
    assert 'kraken trade with unprocessable pair %$#%$#%$#%$#%$#%' in errors[2]
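The ExitStack loop above is a handy way to enter a variable collection of context managers while skipping empty slots. A minimal, self-contained sketch of that pattern (fake_patch below is a made-up stand-in for the mock patches the test enters):

from contextlib import ExitStack, contextmanager

@contextmanager
def fake_patch(name):
    # Stand-in for a mock.patch(...) context manager (hypothetical).
    print('enter', name)
    yield name
    print('exit', name)

managers = [fake_patch('a'), None, fake_patch('b')]  # None means "nothing to patch here"

with ExitStack() as stack:
    for manager in managers:
        if manager is None:
            continue
        stack.enter_context(manager)
    # Both entered managers stay active until the end of this block.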
Example #2
def test_DtaleFlask():
    from dtale.app import DtaleFlask, REAPER_TIMEOUT

    with ExitStack() as stack:
        mock_run = stack.enter_context(
            mock.patch("flask.Flask.run", mock.Mock()))
        stack.enter_context(
            mock.patch("socket.gethostname", mock.Mock(return_value="test")))
        mock_timer = stack.enter_context(mock.patch("dtale.app.Timer"))

        tmp = DtaleFlask("dtale", static_url_path="", url="http://test:9999")
        tmp.run(port="9999")

        assert tmp.reaper_on
        assert tmp.shutdown_url == "http://test:9999/shutdown"
        mock_timer.assert_called_once()
        args, _ = mock_timer.call_args
        assert args[0] == REAPER_TIMEOUT
        mock_run.assert_called_once()
        assert tmp.reaper is not None
        timer_instance = mock_timer.return_value
        timer_instance.start.assert_called_once()

        tmp.clear_reaper()
        timer_instance.cancel.assert_called_once()

    with ExitStack() as stack:
        mock_run = stack.enter_context(
            mock.patch("flask.Flask.run", mock.Mock()))
        stack.enter_context(
            mock.patch("socket.gethostname", mock.Mock(return_value="test")))
        mock_timer = stack.enter_context(
            mock.patch("dtale.app.Timer", mock.Mock()))

        tmp = DtaleFlask("dtale", static_url_path="", reaper_on=False)
        tmp.run(port="9999")

        mock_run.assert_called_once()
        assert not tmp.reaper_on
        mock_timer.assert_not_called()

    with ExitStack() as stack:
        mock_run = stack.enter_context(
            mock.patch("flask.Flask.run", mock.Mock()))
        stack.enter_context(
            mock.patch("socket.gethostname", mock.Mock(return_value="test")))
        mock_timer = stack.enter_context(
            mock.patch("dtale.app.Timer", mock.Mock()))

        tmp = DtaleFlask("dtale", static_url_path="")
        tmp.run(debug=True, port="9999")

        mock_run.assert_called_once()
        assert not tmp.reaper_on
        mock_timer.assert_not_called()

    with ExitStack() as stack:
        stack.enter_context(
            mock.patch("socket.gethostname", mock.Mock(return_value="test")))
        tmp = DtaleFlask("dtale",
                         static_url_path="",
                         url="http://test:9999",
                         app_root="/test_route/")
        assert tmp.url_for("static", "test_path") == "/test_route/test_path"
        assert (tmp.url_for("static", "test_path",
                            filename="test_file") == "/test_route/test_file")
Example #3
def main():
    """Simple entry point"""
    args = args_parsing()

    # Populate our internal singletons once and for all.
    Processes()
    Environment()
    configuration = Configuration(config_path=args.config_path)

    # From configuration, gather the entries user-configured.
    available_entries = configuration.get('entries')
    if available_entries is None:
        # If none were specified, lazily mimic a fully-enabled entries list without any configuration.
        available_entries = [{
            'type': entry_name
        } for entry_name in Entries.__members__.keys()]

    output = Output(preferred_distribution=args.distribution,
                    format_to_json=args.json)

    # We will map this function onto our enabled entries to instantiate them.
    def _entry_instantiator(entry: dict) -> Optional[Entry]:
        # Based on **required** `type` field, instantiate the corresponding `Entry` object.
        try:
            return Entries[entry.pop('type')].value(
                name=entry.pop('name', None),  # `name` is fully-optional.
                options=entry,  # Remaining fields should be propagated as options.
            )
        except KeyError as key_error:
            print(
                'Warning: One entry (misses or) uses an invalid `type` field ({}).'
                .format(key_error),
                file=sys.stderr)
            return None

    # Let's use a context manager stack to manage conditional use of `ThreadPoolExecutor`.
    with ExitStack() as cm_stack:
        if not configuration.get('parallel_loading'):
            mapper = map
        else:
            # Instantiate a threads pool to load our enabled entries in parallel.
            # We use threads (and not processes) since most work done by our entries is IO-bound.
            # `max_workers` is manually computed to mimic Python 3.8+ behaviour, but for our needs.
            #   See <https://github.com/python/cpython/pull/13618>.
            executor = cm_stack.enter_context(
                ThreadPoolExecutor(max_workers=min(
                    len(available_entries) or 1, (os.cpu_count() or 1) + 4)))
            mapper = executor.map

        for entry_instance in mapper(_entry_instantiator, available_entries):
            if not entry_instance:
                continue

            output.add_entry(entry_instance)

    output.output()

    # Has the screenshot flag been specified?
    if args.screenshot is not None:
        # If so, but still _falsy_, pass `None` as no output file has been specified by the user.
        try:
            take_screenshot((args.screenshot or None))
        except KeyboardInterrupt:
            print()
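The parallel_loading branch above illustrates a general trick: an ExitStack lets a single with block conditionally own a context manager. A minimal sketch, with the parallel flag and the work function invented for illustration:

import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import ExitStack

def work(item):
    return item * 2   # placeholder for IO-bound work

items = [1, 2, 3]
parallel = True       # hypothetical flag, e.g. read from configuration

with ExitStack() as cm_stack:
    if not parallel:
        mapper = map
    else:
        executor = cm_stack.enter_context(
            ThreadPoolExecutor(max_workers=min(len(items) or 1, (os.cpu_count() or 1) + 4)))
        mapper = executor.map

    results = list(mapper(work, items))

assert results == [2, 4, 6]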
Example #4
            print(x, y)
        f1.write('haha')
        f2.write('shii')

    with open('file1' + dt.strftime('%Y_%m_%d_%H_%M_%S_%f'), 'r') as f:
        lines = f.readlines()
        print(lines)
from contextlib import nested  # Python 2 only; contextlib.nested was removed in Python 3
with nested(open('file1'), open('file2'), open('file3')) as (f1, f2, f3):
    for i in f1:
        j = f2.readline()
        k = f3.readline()
        print(i, j, k)
# py3
from contextlib import ExitStack
with ExitStack() as stack:
    files = [stack.enter_context(open(fname)) for fname in filenames]
    # Do something with "files"
# json
# loads/dumps work with strings; load/dump read from / write to file objects directly
# By default json.dumps escapes non-ASCII characters into "\u535a\u5ba2\u56ed"-style sequences;
# pass ensure_ascii=False to keep them readable:
# json.dumps({'text': "中文"}, ensure_ascii=False, indent=2)
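# A quick sketch of the load/dump vs loads/dumps distinction noted above
# (the file path is only an example):
import json

data = {'text': '中文'}
s = json.dumps(data, ensure_ascii=False, indent=2)        # -> str, characters stay readable
with open('/tmp/data.json', 'w', encoding='utf-8') as fh:
    json.dump(data, fh, ensure_ascii=False)               # writes straight into the file
with open('/tmp/data.json', encoding='utf-8') as fh:
    assert json.load(fh) == json.loads(s)                 # both round-trip back to the dict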

# Reload an already-imported module
# import importlib
# importlib.reload(AugDataset)

# Strings
# Iterate over each unicode character
s = '我是abc'  # in Python 3, str is already unicode, so no decode() is needed
# Python 2 required: s = '我是abc'.decode('utf-8')
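# Short Python 3 sketch of the iteration described above:
for ch in s:
    print(ch)   # prints 我, 是, a, b, c, one character per line
# Decoding is only needed when starting from bytes:
assert '我是abc'.encode('utf-8').decode('utf-8') == s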
Example #5
def test_show_ngrok(unittest, builtin_pkg):
    from dtale.app import show, get_instance, instances
    import dtale.views as views
    import dtale.global_state as global_state

    orig_import = __import__
    mock_flask_ngrok = mock.Mock()
    mock_flask_ngrok._run_ngrok = lambda: "ngrok_host"

    def import_mock(name, *args, **kwargs):
        if name == "flask_ngrok":
            return mock_flask_ngrok
        return orig_import(name, *args, **kwargs)

    test_data = pd.DataFrame([dict(a=1, b=2)])
    with ExitStack() as stack:
        stack.enter_context(
            mock.patch("{}.__import__".format(builtin_pkg),
                       side_effect=import_mock))
        stack.enter_context(mock.patch("dtale.app.USE_NGROK", True))
        stack.enter_context(mock.patch("dtale.app.PY3", True))
        mock_run = stack.enter_context(
            mock.patch("dtale.app.DtaleFlask.run", mock.Mock()))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        mock_requests = stack.enter_context(
            mock.patch("requests.get", mock.Mock()))
        instance = show(data=test_data,
                        subprocess=False,
                        name="foo",
                        ignore_duplicate=True)
        assert "http://ngrok_host" == instance._url
        mock_run.assert_called_once()

        pdt.assert_frame_equal(instance.data, test_data)
        tmp = test_data.copy()
        tmp["biz"] = 2.5
        instance.data = tmp
        unittest.assertEqual(
            global_state.DTYPES[instance._data_id],
            views.build_dtypes_state(tmp),
            "should update app data/dtypes",
        )

        instance2 = get_instance(instance._data_id)
        assert instance2._url == instance._url
        instances()

        assert get_instance(20) is None  # should return None for invalid data ids

        instance.kill()
        mock_requests.assert_called_once()
        assert mock_requests.call_args[0][0] == "http://ngrok_host/shutdown"
        assert global_state.METADATA["1"]["name"] == "foo"

    with ExitStack() as stack:
        stack.enter_context(mock.patch("dtale.app.USE_NGROK", True))
        stack.enter_context(mock.patch("dtale.app.PY3", False))
        with pytest.raises(Exception):
            show(data=test_data)
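The import_mock helper above is a reusable way to fake an optional dependency that may not be installed (builtin_pkg in the test presumably resolves to "builtins" on Python 3). A self-contained sketch of the same idea; the module name being faked here is invented:

import builtins
from contextlib import ExitStack
from unittest import mock

real_import = builtins.__import__
fake_module = mock.Mock()
fake_module.answer = 42

def import_mock(name, *args, **kwargs):
    if name == "not_really_installed":   # hypothetical optional dependency
        return fake_module
    return real_import(name, *args, **kwargs)

with ExitStack() as stack:
    stack.enter_context(mock.patch("builtins.__import__", side_effect=import_mock))
    import not_really_installed
    assert not_really_installed.answer == 42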
Example #6
def train(args: argparse.Namespace):
    if args.dry_run:
        # Modify arguments so that we write to a temporary directory and
        # perform 0 training iterations
        temp_dir = tempfile.TemporaryDirectory()  # Will be automatically removed
        args.output = temp_dir.name
        args.max_updates = 0

    utils.seedRNGs(args.seed)

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training = check_resume(args, output_folder)

    global logger
    logger = setup_main_logger(__name__,
                               file_logging=True,
                               console=not args.quiet,
                               path=os.path.join(output_folder, C.LOG_NAME))
    utils.log_basic_info(args)
    arguments.save_args(args, os.path.join(output_folder, C.ARGS_STATE_NAME))

    max_seq_len_source, max_seq_len_target = args.max_seq_len
    # The maximum length is the length before we add the BOS/EOS symbols
    max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
    max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
    logger.info(
        "Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
        max_seq_len_source, max_seq_len_target)

    with ExitStack() as exit_stack:
        context = determine_context(args, exit_stack)

        train_iter, eval_iter, config_data, source_vocabs, target_vocab = create_data_iters_and_vocabs(
            args=args,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            shared_vocab=use_shared_vocab(args),
            resume_training=resume_training,
            output_folder=output_folder)
        max_seq_len_source = config_data.max_seq_len_source
        max_seq_len_target = config_data.max_seq_len_target

        # Dump the vocabularies if we're just starting up
        if not resume_training:
            vocab.save_source_vocabs(source_vocabs, output_folder)
            vocab.save_target_vocab(target_vocab, output_folder)

        source_vocab_sizes = [len(v) for v in source_vocabs]
        target_vocab_size = len(target_vocab)
        logger.info('Vocabulary sizes: source=[%s] target=%d',
                    '|'.join([str(size) for size in source_vocab_sizes]),
                    target_vocab_size)

        model_config = create_model_config(
            args=args,
            source_vocab_sizes=source_vocab_sizes,
            target_vocab_size=target_vocab_size,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            config_data=config_data)
        model_config.freeze()

        training_model = create_training_model(config=model_config,
                                               context=context,
                                               output_dir=output_folder,
                                               train_iter=train_iter,
                                               args=args)

        # Handle options that override training settings
        min_updates = args.min_updates
        max_updates = args.max_updates
        min_samples = args.min_samples
        max_samples = args.max_samples
        max_num_checkpoint_not_improved = args.max_num_checkpoint_not_improved
        min_epochs = args.min_num_epochs
        max_epochs = args.max_num_epochs
        if min_epochs is not None and max_epochs is not None:
            check_condition(
                min_epochs <= max_epochs,
                "Minimum number of epochs must be smaller than maximum number of epochs"
            )
        # Fixed training schedule always runs for a set number of updates
        if args.learning_rate_schedule:
            min_updates = None
            max_updates = sum(num_updates
                              for (_,
                                   num_updates) in args.learning_rate_schedule)
            max_num_checkpoint_not_improved = -1
            min_samples = None
            max_samples = None
            min_epochs = None
            max_epochs = None

        trainer = training.EarlyStoppingTrainer(
            model=training_model,
            optimizer_config=create_optimizer_config(args, source_vocab_sizes),
            max_params_files_to_keep=args.keep_last_params,
            source_vocabs=source_vocabs,
            target_vocab=target_vocab)

        trainer.fit(train_iter=train_iter,
                    validation_iter=eval_iter,
                    early_stopping_metric=args.optimized_metric,
                    metrics=args.metrics,
                    checkpoint_frequency=args.checkpoint_frequency,
                    max_num_not_improved=max_num_checkpoint_not_improved,
                    min_samples=min_samples,
                    max_samples=max_samples,
                    min_updates=min_updates,
                    max_updates=max_updates,
                    min_epochs=min_epochs,
                    max_epochs=max_epochs,
                    lr_decay_param_reset=args.learning_rate_decay_param_reset,
                    lr_decay_opt_states_reset=args.learning_rate_decay_optimizer_states_reset,
                    decoder=create_checkpoint_decoder(args, exit_stack,
                                                      context),
                    mxmonitor_pattern=args.monitor_pattern,
                    mxmonitor_stat_func=args.monitor_stat_func,
                    allow_missing_parameters=args.allow_missing_params
                    or model_config.lhuc,
                    existing_parameters=args.params)
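determine_context() above receives the surrounding exit_stack so that any device locks it acquires stay held for the whole training block. A tiny, generic sketch of that hand-off (the helper and file paths are made up):

from contextlib import ExitStack

def open_all(paths, exit_stack):
    # The helper registers resources on the caller's stack instead of
    # closing them itself, so they outlive the helper call.
    return [exit_stack.enter_context(open(p, 'w')) for p in paths]

with ExitStack() as exit_stack:
    files = open_all(['/tmp/a.txt', '/tmp/b.txt'], exit_stack)   # hypothetical paths
    for f in files:
        f.write('still open here\n')
# Both files are closed automatically when the with block exits.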
Example #7
def score(args: argparse.Namespace):
    global logger
    logger = setup_main_logger(__name__, file_logging=False)

    utils.log_basic_info(args)

    with ExitStack() as exit_stack:
        context = utils.determine_context(
            device_ids=args.device_ids,
            use_cpu=args.use_cpu,
            disable_device_locking=args.disable_device_locking,
            lock_dir=args.lock_dir,
            exit_stack=exit_stack)
        if args.batch_type == C.BATCH_TYPE_SENTENCE:
            check_condition(
                args.batch_size % len(context) == 0,
                "When using multiple devices the batch size must be "
                "divisible by the number of devices. Choose a batch "
                "size that is a multiple of %d." % len(context))
        logger.info("Scoring Device(s): %s",
                    ", ".join(str(c) for c in context))

        # This call has a number of different parameters compared to training which reflect our need to get scores
        # one-for-one and in the same order as the input data.
        # To enable code reuse, we stuff the `args` parameter with some values.
        # Bucketing and permuting need to be turned off in order to preserve the ordering of sentences.
        # The 'zeros' fill_up strategy fills underfilled buckets with zeros which can then be used to find the last item.
        # Finally, 'resume_training' needs to be set to True because it causes the model to be loaded instead of initialized.
        args.no_bucketing = True
        args.fill_up = 'zeros'
        args.bucket_width = 10
        score_iter, config_data, source_vocabs, target_vocab, model_config = get_data_iters_and_vocabs(
            args=args, model_folder=args.model)

        scoring_model = scoring.ScoringModel(
            config=model_config,
            model_dir=args.model,
            context=context,
            provide_data=score_iter.provide_data,
            provide_label=score_iter.provide_label,
            default_bucket_key=score_iter.default_bucket_key,
            score_type=args.score_type,
            bucketing=False,
            length_penalty=inference.LengthPenalty(
                alpha=args.length_penalty_alpha,
                beta=args.length_penalty_beta),
            softmax_temperature=args.softmax_temperature)

        scorer = scoring.Scorer(scoring_model, source_vocabs, target_vocab)

        scorer.score(score_iter=score_iter,
                     output_handler=get_output_handler(
                         output_type=args.output_type,
                         output_fname=args.output))

        if config_data.data_statistics.num_discarded != 0:
            num_discarded = config_data.data_statistics.num_discarded
            logger.warning(
                'Warning: %d %s longer than %s %s skipped. '
                'As a result, the output won\'t be parallel with the input. '
                'Increase the maximum length (--max-seq-len M:N) or trim your training data.',
                num_discarded, utils.inflect('sentence',
                                             num_discarded), args.max_seq_len,
                utils.inflect('was', num_discarded))
Example #8
def test_query_online_trade_history_case_2(mock_bitfinex):
    """Test pagination logic for trades works as expected when a request
    returns a result already processed in the previous request.

    Other things tested:
      - Stop requesting when number of entries is less than limit.

    First request: 2 results
    Second request: 2 results, both trades are repeated from the 1st request.
    Third request: 2 results, first trade is repeated from the 2nd request.
    Fourth request: 1 result

    Trades with id 1 to 4 are expected to be returned.
    """
    api_limit = 2
    mock_bitfinex.first_connection = MagicMock()
    mock_bitfinex.currency_map = {
        'UST': 'USDt',
        'WBT': 'WBTC',
    }
    mock_bitfinex.pair_bfx_symbols_map = {
        'ETHUST': ('ETH', 'UST'),
        'WBTUSD': ('WBT', 'USD'),
        'ETHEUR': ('ETH', 'EUR'),
    }
    # Buy ETH with USDT
    trade_1 = """
    [
        1,
        "tETH:UST",
        1606899600000,
        10,
        0.26334268,
        187.37,
        "LIMIT",
        null,
        -1,
        -0.09868591,
        "UST"
    ]
    """
    # Sell ETH for USDT
    trade_2 = """
    [
        2,
        "tETH:UST",
        1606901400000,
        20,
        -0.26334268,
        187.37,
        "LIMIT",
        null,
        -1,
        -0.09868591,
        "ETH"
    ]
    """
    # Buy WBTC for USD
    trade_3 = """
    [
        3,
        "tWBTUSD",
        1606932000000,
        30,
        10000.00000000,
        0.00005000,
        "LIMIT",
        null,
        -1,
        -20.00000000,
        "USD"
    ]
    """
    # Sell WBTC for USD
    trade_4 = """
    [
        4,
        "tWBTUSD",
        1606986000000,
        40,
        -10000.00000000,
        0.00005000,
        "LIMIT",
        null,
        -1,
        -20.00000000,
        "WBT"
    ]
    """

    def get_paginated_response():
        results = [
            f'[{trade_1},{trade_2}]',
            f'[{trade_1},{trade_2}]',  # repeated line
            f'[{trade_2},{trade_3}]',  # contains repeated
            f'[{trade_4}]',
        ]
        for result_ in results:
            yield result_

    def mock_api_query_response(endpoint, options):  # pylint: disable=unused-argument
        return MockResponse(HTTPStatus.OK, next(get_response))

    get_response = get_paginated_response()
    api_limit_patch = patch(
        target='rotkehlchen.exchanges.bitfinex.API_TRADES_MAX_LIMIT',
        new=api_limit,
    )
    api_query_patch = patch.object(
        target=mock_bitfinex,
        attribute='_api_query',
        side_effect=mock_api_query_response,
    )
    with ExitStack() as stack:
        stack.enter_context(api_limit_patch)
        stack.enter_context(api_query_patch)
        trades = mock_bitfinex.query_online_trade_history(
            start_ts=Timestamp(0),
            end_ts=Timestamp(int(datetime.now().timestamp())),
        )
        expected_trades = [
            Trade(
                timestamp=Timestamp(1606899600),
                location=Location.BITFINEX,
                pair=TradePair('ETH_USDT'),
                trade_type=TradeType.BUY,
                amount=AssetAmount(FVal('0.26334268')),
                rate=Price(FVal('187.37')),
                fee=Fee(FVal('0.09868591')),
                fee_currency=Asset('USDT'),
                link='1',
                notes='',
            ),
            Trade(
                timestamp=Timestamp(1606901400),
                location=Location.BITFINEX,
                pair=TradePair('ETH_USDT'),
                trade_type=TradeType.SELL,
                amount=AssetAmount(FVal('0.26334268')),
                rate=Price(FVal('187.37')),
                fee=Fee(FVal('0.09868591')),
                fee_currency=Asset('ETH'),
                link='2',
                notes='',
            ),
            Trade(
                timestamp=Timestamp(1606932000),
                location=Location.BITFINEX,
                pair=TradePair('WBTC_USD'),
                trade_type=TradeType.BUY,
                amount=AssetAmount(FVal('10000.0')),
                rate=Price(FVal('0.00005')),
                fee=Fee(FVal('20.0')),
                fee_currency=Asset('USD'),
                link='3',
                notes='',
            ),
            Trade(
                timestamp=Timestamp(1606986000),
                location=Location.BITFINEX,
                pair=TradePair('WBTC_USD'),
                trade_type=TradeType.SELL,
                amount=AssetAmount(FVal('10000.0')),
                rate=Price(FVal('0.00005')),
                fee=Fee(FVal('20.0')),
                fee_currency=Asset('WBTC'),
                link='4',
                notes='',
            ),
        ]
        assert trades == expected_trades
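The get_paginated_response generator plus mock_api_query_response above is a compact way to feed successive canned pages to the code under test. A bare-bones version of the same mechanism:

from unittest.mock import MagicMock

def pages():
    yield '[1,2]'
    yield '[3]'

page_iter = pages()
api_query = MagicMock(side_effect=lambda *args, **kwargs: next(page_iter))

assert api_query(endpoint='trades') == '[1,2]'   # first call -> first page
assert api_query(endpoint='trades') == '[3]'     # second call -> second page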
Example #9
def test_query_online_deposits_withdrawals_case_1(mock_bitfinex):
    """Test pagination logic for asset movements works as expected when each
    request does not return a result already processed.

    Other things tested:
      - Results are sorted by id in ascending mode.
      - Skip result when status is not 'COMPLETED'.
      - Stop requesting (break the loop) when result timestamp is greater than
      'end_ts'.
      - '_api_query' call arguments.

    First request: 2 results
    Second request: 2 results, 1 not completed.
    Third request: 1 result, out of time range (its timestamp is gt 'end_ts')

    Movements with id 1, 2 and 4 are expected to be returned.
    """
    api_limit = 2
    mock_bitfinex.first_connection = MagicMock()
    mock_bitfinex.currency_map = {'WBT': 'WBTC'}
    # Deposit WBTC
    movement_1 = """
    [
        1,
        "WBT",
        "Wrapped Bitcoin",
        null,
        null,
        1606899600000,
        1606899700000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        0.26300954,
        -0.00135,
        null,
        null,
        "DESTINATION_ADDRESS",
        null,
        null,
        null,
        "TRANSACTION_ID",
        null
    ]
    """
    # Withdraw WBTC
    movement_2 = """
    [
        2,
        "WBT",
        "Wrapped Bitcoin",
        null,
        null,
        1606901400000,
        1606901500000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        -0.26300954,
        -0.00135,
        null,
        null,
        "DESTINATION_ADDRESS",
        null,
        null,
        null,
        "TRANSACTION_ID",
        null
    ]
    """
    # Deposit WBTC, not completed
    movement_3 = """
    [
        3,
        "WBT",
        "Wrapped Bitcoin",
        null,
        null,
        1606932000000,
        1606932100000,
        null,
        null,
        "WHATEVER",
        null,
        null,
        0.26300954,
        -0.00135,
        null,
        null,
        "DESTINATION_ADDRESS",
        null,
        null,
        null,
        "TRANSACTION_ID",
        null
    ]
    """
    # Withdraw EUR
    movement_4 = """
    [
        4,
        "EUR",
        "Euro",
        null,
        null,
        1606986000000,
        1606986100000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        -0.26300954,
        -0.00135,
        null,
        null,
        "",
        null,
        null,
        null,
        "",
        null
    ]
    """
    # Deposit WBTC, outside time range (gt 'end_ts')
    movement_5 = """
    [
        5,
        "WBT",
        "Wrapped Bitcoin",
        null,
        null,
        1606996801000,
        1606996901000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        0.26300954,
        -0.00135,
        null,
        null,
        "DESTINATION_ADDRESS",
        null,
        null,
        null,
        "TRANSACTION_ID",
        null
    ]
    """
    expected_calls = [
        call(
            endpoint='movements',
            options={
                'start': 0,
                'end': 1606996800000,
                'limit': 2,
            },
        ),
        call(
            endpoint='movements',
            options={
                'start': 1606901400000,
                'end': 1606996800000,
                'limit': 2,
            },
        ),
        call(
            endpoint='movements',
            options={
                'start': 1606986000000,
                'end': 1606996800000,
                'limit': 2,
            },
        ),
    ]

    def get_paginated_response():
        results = [
            f'[{movement_2},{movement_1}]',
            f'[{movement_4},{movement_3}]',
            f'[{movement_5}]',
        ]
        for result_ in results:
            yield result_

    def mock_api_query_response(endpoint, options):  # pylint: disable=unused-argument
        return MockResponse(HTTPStatus.OK, next(get_response))

    get_response = get_paginated_response()
    api_limit_patch = patch(
        target='rotkehlchen.exchanges.bitfinex.API_MOVEMENTS_MAX_LIMIT',
        new=api_limit,
    )
    api_query_patch = patch.object(
        target=mock_bitfinex,
        attribute='_api_query',
        side_effect=mock_api_query_response,
    )
    with ExitStack() as stack:
        stack.enter_context(api_limit_patch)
        api_query_mock = stack.enter_context(api_query_patch)
        asset_movements = mock_bitfinex.query_online_deposits_withdrawals(
            start_ts=Timestamp(0),
            end_ts=Timestamp(int(datetime.now().timestamp())),
        )
        assert api_query_mock.call_args_list == expected_calls

        wbtc_fee_asset = Asset('WBTC')
        eur_fee_asset = Asset('EUR')
        expected_asset_movements = [
            AssetMovement(
                timestamp=Timestamp(1606899600),
                location=Location.BITFINEX,
                category=AssetMovementCategory.DEPOSIT,
                address='DESTINATION_ADDRESS',
                transaction_id='TRANSACTION_ID',
                asset=wbtc_fee_asset,
                amount=FVal('0.26300954'),
                fee_asset=wbtc_fee_asset,
                fee=Fee(FVal('0.00135')),
                link=str(1),
            ),
            AssetMovement(
                timestamp=Timestamp(1606901400),
                location=Location.BITFINEX,
                category=AssetMovementCategory.WITHDRAWAL,
                address='DESTINATION_ADDRESS',
                transaction_id='TRANSACTION_ID',
                asset=wbtc_fee_asset,
                amount=FVal('0.26300954'),
                fee_asset=wbtc_fee_asset,
                fee=Fee(FVal('0.00135')),
                link=str(2),
            ),
            AssetMovement(
                timestamp=Timestamp(1606986000),
                location=Location.BITFINEX,
                category=AssetMovementCategory.WITHDRAWAL,
                address=None,
                transaction_id=None,
                asset=eur_fee_asset,
                amount=FVal('0.26300954'),
                fee_asset=eur_fee_asset,
                fee=Fee(FVal('0.00135')),
                link=str(4),
            ),
        ]
        assert asset_movements == expected_asset_movements
Example #10
def test_query_online_deposits_withdrawals_case_2(mock_bitfinex):
    """Test pagination logic for asset movements works as expected when a
    request returns a result already processed in the previous request.

    Other things tested:
      - Stop requesting when number of entries is less than limit.

    First request: 2 results
    Second request: 2 results, both movements are repeated from the 1st request.
    Third request: 2 results, first movement is repeated from the 2nd request.
    Fourth request: 1 result

    Movements with id 1 to 4 are expected to be returned.
    """
    api_limit = 2
    mock_bitfinex.first_connection = MagicMock()
    mock_bitfinex.currency_map = {'WBT': 'WBTC'}
    # Deposit WBTC
    movement_1 = """
    [
        1,
        "WBT",
        "Wrapped Bitcoin",
        null,
        null,
        1606899600000,
        1606899700000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        0.26300954,
        -0.00135,
        null,
        null,
        "DESTINATION_ADDRESS",
        null,
        null,
        null,
        "TRANSACTION_ID",
        null
    ]
    """
    # Withdraw WBTC
    movement_2 = """
    [
        2,
        "WBT",
        "Wrapped Bitcoin",
        null,
        null,
        1606901400000,
        1606901500000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        -0.26300954,
        -0.00135,
        null,
        null,
        "DESTINATION_ADDRESS",
        null,
        null,
        null,
        "TRANSACTION_ID",
        null
    ]
    """
    # Withdraw EUR
    movement_3 = """
    [
        3,
        "EUR",
        "Euro",
        null,
        null,
        1606986000000,
        1606986100000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        -0.26300954,
        -0.00135,
        null,
        null,
        "",
        null,
        null,
        null,
        "",
        null
    ]
    """
    # Deposit WBTC
    movement_4 = """
    [
        4,
        "WBT",
        "Wrapped Bitcoin",
        null,
        null,
        1606996800000,
        1606996900000,
        null,
        null,
        "COMPLETED",
        null,
        null,
        0.26300954,
        -0.00135,
        null,
        null,
        "DESTINATION_ADDRESS",
        null,
        null,
        null,
        "TRANSACTION_ID",
        null
    ]
    """

    def get_paginated_response():
        results = [
            f'[{movement_2},{movement_1}]',
            f'[{movement_2},{movement_1}]',
            f'[{movement_3},{movement_2}]',
            f'[{movement_4}]',
        ]
        for result_ in results:
            yield result_

    def mock_api_query_response(endpoint, options):  # pylint: disable=unused-argument
        return MockResponse(HTTPStatus.OK, next(get_response))

    get_response = get_paginated_response()
    api_limit_patch = patch(
        target='rotkehlchen.exchanges.bitfinex.API_MOVEMENTS_MAX_LIMIT',
        new=api_limit,
    )
    api_query_patch = patch.object(
        target=mock_bitfinex,
        attribute='_api_query',
        side_effect=mock_api_query_response,
    )
    with ExitStack() as stack:
        stack.enter_context(api_limit_patch)
        stack.enter_context(api_query_patch)
        asset_movements = mock_bitfinex.query_online_deposits_withdrawals(
            start_ts=Timestamp(0),
            end_ts=Timestamp(int(datetime.now().timestamp())),
        )
        wbtc_fee_asset = Asset('WBTC')
        eur_fee_asset = Asset('EUR')
        expected_asset_movements = [
            AssetMovement(
                timestamp=Timestamp(1606899600),
                location=Location.BITFINEX,
                category=AssetMovementCategory.DEPOSIT,
                address='DESTINATION_ADDRESS',
                transaction_id='TRANSACTION_ID',
                asset=wbtc_fee_asset,
                amount=FVal('0.26300954'),
                fee_asset=wbtc_fee_asset,
                fee=Fee(FVal('0.00135')),
                link=str(1),
            ),
            AssetMovement(
                timestamp=Timestamp(1606901400),
                location=Location.BITFINEX,
                category=AssetMovementCategory.WITHDRAWAL,
                address='DESTINATION_ADDRESS',
                transaction_id='TRANSACTION_ID',
                asset=wbtc_fee_asset,
                amount=FVal('0.26300954'),
                fee_asset=wbtc_fee_asset,
                fee=Fee(FVal('0.00135')),
                link=str(2),
            ),
            AssetMovement(
                timestamp=Timestamp(1606986000),
                location=Location.BITFINEX,
                category=AssetMovementCategory.WITHDRAWAL,
                address=None,
                transaction_id=None,
                asset=eur_fee_asset,
                amount=FVal('0.26300954'),
                fee_asset=eur_fee_asset,
                fee=Fee(FVal('0.00135')),
                link=str(3),
            ),
            AssetMovement(
                timestamp=Timestamp(1606996800),
                location=Location.BITFINEX,
                category=AssetMovementCategory.DEPOSIT,
                address='DESTINATION_ADDRESS',
                transaction_id='TRANSACTION_ID',
                asset=wbtc_fee_asset,
                amount=FVal('0.26300954'),
                fee_asset=wbtc_fee_asset,
                fee=Fee(FVal('0.00135')),
                link=str(4),
            ),
        ]
        assert asset_movements == expected_asset_movements
Example #11
def test_query_online_trade_history_case_1(mock_bitfinex):
    """Test pagination logic for trades works as expected when each request
    does not return a result already processed.

    Other things tested:
      - Stop requesting (break the loop) when result timestamp is greater than
      'end_ts'.
      - '_api_query' call arguments.

    First request: 2 results
    Second request: 2 results
    Third request: 1 result, out of time range (its timestamp is gt 'end_ts')

    Trades with id 1 to 4 are expected to be returned.
    """
    api_limit = 2
    mock_bitfinex.first_connection = MagicMock()
    mock_bitfinex.currency_map = {
        'UST': 'USDt',
        'WBT': 'WBTC',
    }
    mock_bitfinex.pair_bfx_symbols_map = {
        'ETHUST': ('ETH', 'UST'),
        'WBTUSD': ('WBT', 'USD'),
        'ETHEUR': ('ETH', 'EUR'),
    }
    # Buy ETH with USDT
    trade_1 = """
    [
        1,
        "tETH:UST",
        1606899600000,
        10,
        0.26334268,
        187.37,
        "LIMIT",
        null,
        -1,
        -0.09868591,
        "USD"
    ]
    """
    # Sell ETH for USDT
    trade_2 = """
    [
        2,
        "tETH:UST",
        1606901400000,
        20,
        -0.26334268,
        187.37,
        "LIMIT",
        null,
        -1,
        -0.09868591,
        "ETH"
    ]
    """
    # Buy WBTC for USD
    trade_3 = """
    [
        3,
        "tWBTUSD",
        1606932000000,
        30,
        10000.00000000,
        0.00005000,
        "LIMIT",
        null,
        -1,
        -20.00000000,
        "USD"
    ]
    """
    # Sell WBTC for USD
    trade_4 = """
    [
        4,
        "tWBTUSD",
        1606986000000,
        40,
        -10000.00000000,
        0.00005000,
        "LIMIT",
        null,
        -1,
        -20.00000000,
        "BTC"
    ]
    """
    # Sell ETH for EUR, outside time range (gt 'end_ts')
    trade_5 = """
    [
        5,
        "tETH:EUR",
        1606996801000,
        50,
        -0.26334268,
        163.29,
        "LIMIT",
        null,
        -1,
        -0.09868591,
        "ETH"
    ]
    """
    expected_calls = [
        call(
            endpoint='trades',
            options={
                'start': 0,
                'end': 1606996800000,
                'limit': 2,
                'sort': 1,
            },
        ),
        call(
            endpoint='trades',
            options={
                'start': 1606901400000,
                'end': 1606996800000,
                'limit': 2,
                'sort': 1,
            },
        ),
        call(
            endpoint='trades',
            options={
                'start': 1606986000000,
                'end': 1606996800000,
                'limit': 2,
                'sort': 1,
            },
        ),
    ]

    def get_paginated_response():
        results = [
            f'[{trade_1},{trade_2}]',
            f'[{trade_3},{trade_4}]',
            f'[{trade_5}]',
        ]
        for result_ in results:
            yield result_

    def mock_api_query_response(endpoint, options):  # pylint: disable=unused-argument
        return MockResponse(HTTPStatus.OK, next(get_response))

    get_response = get_paginated_response()
    api_limit_patch = patch(
        target='rotkehlchen.exchanges.bitfinex.API_TRADES_MAX_LIMIT',
        new=api_limit,
    )
    api_query_patch = patch.object(
        target=mock_bitfinex,
        attribute='_api_query',
        side_effect=mock_api_query_response,
    )
    with ExitStack() as stack:
        stack.enter_context(api_limit_patch)
        api_query_mock = stack.enter_context(api_query_patch)
        trades = mock_bitfinex.query_online_trade_history(
            start_ts=Timestamp(0),
            end_ts=Timestamp(int(datetime.now().timestamp())),
        )
        assert api_query_mock.call_args_list == expected_calls
        expected_trades = [
            Trade(
                timestamp=Timestamp(1606899600),
                location=Location.BITFINEX,
                pair=TradePair('ETH_USDT'),
                trade_type=TradeType.BUY,
                amount=AssetAmount(FVal('0.26334268')),
                rate=Price(FVal('187.37')),
                fee=Fee(FVal('0.09868591')),
                fee_currency=Asset('USD'),
                link='1',
                notes='',
            ),
            Trade(
                timestamp=Timestamp(1606901400),
                location=Location.BITFINEX,
                pair=TradePair('ETH_USDT'),
                trade_type=TradeType.SELL,
                amount=AssetAmount(FVal('0.26334268')),
                rate=Price(FVal('187.37')),
                fee=Fee(FVal('0.09868591')),
                fee_currency=Asset('ETH'),
                link='2',
                notes='',
            ),
            Trade(
                timestamp=Timestamp(1606932000),
                location=Location.BITFINEX,
                pair=TradePair('WBTC_USD'),
                trade_type=TradeType.BUY,
                amount=AssetAmount(FVal('10000.0')),
                rate=Price(FVal('0.00005')),
                fee=Fee(FVal('20.0')),
                fee_currency=Asset('USD'),
                link='3',
                notes='',
            ),
            Trade(
                timestamp=Timestamp(1606986000),
                location=Location.BITFINEX,
                pair=TradePair('WBTC_USD'),
                trade_type=TradeType.SELL,
                amount=AssetAmount(FVal('10000.0')),
                rate=Price(FVal('0.00005')),
                fee=Fee(FVal('20.0')),
                fee_currency=Asset('BTC'),
                link='4',
                notes='',
            ),
        ]
        assert trades == expected_trades
Example #12
    def test_202__push_not_forced(
        self,
        settings,
        client,
        repository_factory,
        git_hub_repository_factory,
        project_factory,
        task_factory,
    ):
        settings.GITHUB_HOOK_SECRET = b""
        with ExitStack() as stack:
            gh = stack.enter_context(patch("metecho.api.models.gh"))
            gh.get_repo_info.return_value = MagicMock(
                **{
                    "pull_requests.return_value": (
                        MagicMock(number=123, closed_at=None, is_merged=False,)
                        for _ in range(1)
                    ),
                    "compare_commits.return_value": MagicMock(ahead_by=0),
                }
            )
            gh.normalize_commit.return_value = "1234abcd"

            repo = repository_factory(repo_id=123)
            git_hub_repository_factory(repo_id=123)
            project = project_factory(repository=repo, branch_name="test-project")
            task = task_factory(project=project, branch_name="test-task")

            refresh_commits_job = stack.enter_context(
                patch("metecho.api.jobs.refresh_commits_job")
            )
            response = client.post(
                reverse("hook"),
                json.dumps(
                    {
                        "ref": "refs/heads/test-task",
                        "forced": False,
                        "repository": {"id": 123},
                        "commits": [
                            {
                                "id": "123",
                                "author": {
                                    "name": "Test",
                                    "email": "*****@*****.**",
                                    "username": "******",
                                },
                                "timestamp": "2019-11-20 21:32:53.668260+00:00",
                                "message": "Message",
                                "url": "https://github.com/test/user/foo",
                            }
                        ],
                        "sender": {
                            "login": "******",
                            "avatar_url": "https://avatar_url/",
                        },
                    }
                ),
                content_type="application/json",
                # The sha1 hexdigest of the request body x the secret
                # key above:
                HTTP_X_HUB_SIGNATURE="sha1=6a5d470ca262a2522635f1adb71a13b18446dd54",
                HTTP_X_GITHUB_EVENT="push",
            )
            assert response.status_code == 202, response.content
            assert not refresh_commits_job.delay.called
            task.refresh_from_db()
            assert len(task.commits) == 1
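The hard-coded HTTP_X_HUB_SIGNATURE above is, per the inline comment, the HMAC-SHA1 of the request body under the (empty) hook secret. A sketch of how such a header is computed; the body below is a placeholder, not the test's full payload:

import hashlib
import hmac

secret = b""                                 # settings.GITHUB_HOOK_SECRET in the test above
body = b'{"ref": "refs/heads/test-task"}'    # placeholder request body
signature = "sha1=" + hmac.new(secret, body, hashlib.sha1).hexdigest()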
Example #13
    def test_write_ipv6_rhel(self):
        rh_distro = self._get_distro('rhel')

        write_bufs = {}

        def replace_write(filename, content, mode=0o644, omode="wb"):
            buf = WriteBuffer()
            buf.mode = mode
            buf.omode = omode
            buf.write(content)
            write_bufs[filename] = buf

        with ExitStack() as mocks:
            mocks.enter_context(
                mock.patch.object(util, 'write_file', replace_write))
            mocks.enter_context(
                mock.patch.object(util, 'load_file', return_value=''))
            mocks.enter_context(
                mock.patch.object(os.path, 'isfile', return_value=False))

            rh_distro.apply_network(BASE_NET_CFG_IPV6, False)

            self.assertEquals(len(write_bufs), 4)
            self.assertIn('/etc/sysconfig/network-scripts/ifcfg-lo',
                          write_bufs)
            write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-lo']
            expected_buf = '''
DEVICE="lo"
ONBOOT=yes
'''
            self.assertCfgEquals(expected_buf, str(write_buf))
            self.assertEquals(write_buf.mode, 0o644)

            self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth0',
                          write_bufs)
            write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth0']
            expected_buf = '''
DEVICE="eth0"
BOOTPROTO="static"
NETMASK="255.255.255.0"
IPADDR="192.168.1.5"
ONBOOT=yes
GATEWAY="192.168.1.254"
BROADCAST="192.168.1.0"
IPV6INIT=yes
IPV6ADDR="2607:f0d0:1002:0011::2"
IPV6_DEFAULTGW="2607:f0d0:1002:0011::1"
'''
            self.assertCfgEquals(expected_buf, str(write_buf))
            self.assertEquals(write_buf.mode, 0o644)
            self.assertIn('/etc/sysconfig/network-scripts/ifcfg-eth1',
                          write_bufs)
            write_buf = write_bufs['/etc/sysconfig/network-scripts/ifcfg-eth1']
            expected_buf = '''
DEVICE="eth1"
BOOTPROTO="static"
NETMASK="255.255.255.0"
IPADDR="192.168.1.6"
ONBOOT=no
GATEWAY="192.168.1.254"
BROADCAST="192.168.1.0"
IPV6INIT=yes
IPV6ADDR="2607:f0d0:1002:0011::3"
IPV6_DEFAULTGW="2607:f0d0:1002:0011::1"
'''
            self.assertCfgEquals(expected_buf, str(write_buf))
            self.assertEquals(write_buf.mode, 0o644)

            self.assertIn('/etc/sysconfig/network', write_bufs)
            write_buf = write_bufs['/etc/sysconfig/network']
            expected_buf = '''
# Created by cloud-init v. 0.7
NETWORKING=yes
NETWORKING_IPV6=yes
IPV6_AUTOCONF=no
'''
            self.assertCfgEquals(expected_buf, str(write_buf))
            self.assertEquals(write_buf.mode, 0o644)
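replace_write above swaps the distro's file writer for a closure that records every write into write_bufs, so the test can inspect the generated configs without touching the filesystem. A self-contained miniature of that pattern (fake_util is invented for the demo):

from types import SimpleNamespace
from unittest import mock

# Invented stand-in for a module that writes config files.
fake_util = SimpleNamespace(write_file=lambda filename, content, mode=0o644: None)

write_bufs = {}

def replace_write(filename, content, mode=0o644):
    write_bufs[filename] = content   # record instead of writing

with mock.patch.object(fake_util, 'write_file', replace_write):
    fake_util.write_file('/etc/sysconfig/network', 'NETWORKING=yes\n')

assert write_bufs == {'/etc/sysconfig/network': 'NETWORKING=yes\n'}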
Example #14
def acquire_gpus(requested_device_ids: List[int], lock_dir: str = "/tmp",
                 retry_wait_min: int = 10, retry_wait_rand: int = 60,
                 num_gpus_available: Optional[int]=None):
    """
    Acquire a number of GPUs in a transactional way. This method should be used inside a `with` statement.
    Tries to acquire the requested number of GPUs. If not enough GPUs are currently
    available, all locks are released and we wait before retrying. Retries until
    enough GPUs become available.

    :param requested_device_ids: The requested device ids, each number is either negative indicating the number of GPUs
     that will be allocated, or positive indicating we want to acquire a specific device id.
    :param lock_dir: The directory for storing the lock file.
    :param retry_wait_min: The minimum number of seconds to wait between retries.
    :param retry_wait_rand: Randomly add between 0 and `retry_wait_rand` seconds to the wait time.
    :param num_gpus_available: The number of GPUs available, if None we will call get_num_gpus().
    :return: yields a list of GPU ids.
    """
    if num_gpus_available is None:
        num_gpus_available = get_num_gpus()
    if num_gpus_available == 0:
        raise RuntimeError("Can not acquire GPU, as no GPUs were found on this machine.")

    if not os.path.exists(lock_dir):
        raise IOError("Lock directory %s does not exist." % lock_dir)

    if not os.access(lock_dir, os.W_OK):
        raise IOError("Lock directory %s is not writeable." % lock_dir)

    # split the device ids into the specific ids requested and count up the number of arbitrary ids we want
    # e.g. device_ids = [-3, 2, 5, 7, -5] means we want to acquire device 2, 5 and 7 plus 8 other devices.
    specific_device_ids = set()  # type: Set[int]
    num_arbitrary_device_ids = 0
    for device_id in requested_device_ids:
        if device_id < 0:
            num_gpus = -device_id
            num_arbitrary_device_ids += num_gpus
        else:
            if device_id in specific_device_ids:
                raise ValueError("Requested GPU %d twice." % device_id)
            specific_device_ids.add(device_id)

    # make sure we have enough GPUs available
    num_gpus_requested = len(specific_device_ids) + num_arbitrary_device_ids
    if num_gpus_requested > num_gpus_available:
        raise ValueError("Requested %d GPUs, but only %d are available." % (num_gpus_requested, num_gpus_available))
    logger.info("Attempting to acquire %d GPUs of %d GPUs. The requested devices are: %s",
                num_gpus_requested, num_gpus_available, str(requested_device_ids))

    # note: it's important to first allocate the specific device ids and then the others to not deadlock ourselves.

    # for specific device ids we just have the device id itself as a candidate
    candidates_to_request = [[device_id] for device_id in specific_device_ids]

    # for the arbitrary device ids we take all remaining device ids as a list of candidates
    remaining_device_ids = [device_id for device_id in range(num_gpus_available)
                            if device_id not in specific_device_ids]
    candidates_to_request += [remaining_device_ids for _ in range(num_arbitrary_device_ids)]

    while True:

        with ExitStack() as exit_stack:
            any_failed = False
            acquired_gpus = []  # type: List[int]
            with GpuFileLock(candidates=["master_lock"], lock_dir=lock_dir) as master_lock:  # type: str
                # Only one process, determined by the master lock, can try acquiring gpu locks at a time.
                # This will make sure that we use consecutive device ids whenever possible.
                if master_lock is not None:
                    for candidates in candidates_to_request:
                        gpu_id = exit_stack.enter_context(GpuFileLock(candidates=candidates, lock_dir=lock_dir))
                        if gpu_id is not None:
                            acquired_gpus.append(cast(int, gpu_id))
                        else:
                            if len(candidates) == 1:
                                logger.info("Could not acquire GPU %d. It's currently locked.", candidates[0])
                            any_failed = True
                            break
            if master_lock is not None and not any_failed:
                try:
                    yield acquired_gpus
                except:
                    raise
                return

        # randomize so that multiple processes starting at the same time don't retry at a similar point in time
        if retry_wait_rand > 0:
            retry_wait_actual = retry_wait_min + random.randint(0, retry_wait_rand)
        else:
            retry_wait_actual = retry_wait_min

        if master_lock is None:
            logger.info("Another process is acquiring GPUs at the moment will try again in %ss." % retry_wait_actual)
        else:
            logger.info("Not enough GPUs available will try again in %ss." % retry_wait_actual)
        time.sleep(retry_wait_actual)
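acquire_gpus() is written as a generator, so it is presumably wrapped with contextlib.contextmanager somewhere outside this snippet before being used in a with statement. A hedged usage sketch under that assumption; the device ids and lock directory are illustrative, and the call only succeeds on a machine with GPUs:

from contextlib import contextmanager

acquire_gpus_cm = contextmanager(acquire_gpus)

# Request device 2 plus two arbitrary devices, similar to the docstring's example.
with acquire_gpus_cm([2, -2], lock_dir="/tmp") as gpu_ids:
    print("acquired GPUs:", gpu_ids)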
Example #15
    def __init__(self,
                 model_folder: str,
                 inputs: List[str],
                 references: List[str],
                 source_vocabs: List[vocab.Vocab],
                 target_vocabs: List[vocab.Vocab],
                 model: sockeye.model.SockeyeModel,
                 context: mx.Context,
                 max_input_len: Optional[int] = None,
                 batch_size: int = 16,
                 beam_size: int = C.DEFAULT_BEAM_SIZE,
                 nbest_size: int = C.DEFAULT_NBEST_SIZE,
                 bucket_width_source: int = 10,
                 length_penalty_alpha: float = 1.0,
                 length_penalty_beta: float = 0.0,
                 max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH,
                 ensemble_mode: str = 'linear',
                 sample_size: int = -1,
                 random_seed: int = 42,
                 hybridize: bool = True) -> None:
        self.max_input_len = max_input_len
        self.max_output_length_num_stds = max_output_length_num_stds
        self.ensemble_mode = ensemble_mode
        self.beam_size = beam_size
        self.nbest_size = nbest_size
        self.batch_size = batch_size
        self.bucket_width_source = bucket_width_source
        self.length_penalty_alpha = length_penalty_alpha
        self.length_penalty_beta = length_penalty_beta
        self.model = model

        with ExitStack() as exit_stack:
            inputs_fins = [
                exit_stack.enter_context(data_io.smart_open(f)) for f in inputs
            ]
            references_fins = [
                exit_stack.enter_context(data_io.smart_open(f))
                for f in references
            ]

            inputs_sentences = [f.readlines() for f in inputs_fins]
            targets_sentences = [f.readlines() for f in references_fins]

            utils.check_condition(
                all(
                    len(l) == len(targets_sentences[0])
                    for l in chain(inputs_sentences, targets_sentences)),
                "Sentences differ in length")
            utils.check_condition(
                all(
                    len(sentence.strip()) > 0
                    for sentence in targets_sentences[0]),
                "Empty target validation sentence.")

            if sample_size <= 0:
                sample_size = len(inputs_sentences[0])
            if sample_size < len(inputs_sentences[0]):
                sentences = parallel_subsample(
                    inputs_sentences + targets_sentences, sample_size,
                    random_seed)
                self.inputs_sentences = sentences[0:len(inputs_sentences)]
                self.targets_sentences = sentences[len(inputs_sentences):]
            else:
                self.inputs_sentences, self.targets_sentences = inputs_sentences, targets_sentences

            if sample_size < self.batch_size:
                self.batch_size = sample_size
        for factor_idx, factor in enumerate(self.inputs_sentences):
            write_to_file(
                factor,
                os.path.join(model_folder,
                             C.DECODE_IN_NAME.format(factor=factor_idx)))
        for factor_idx, factor in enumerate(self.targets_sentences):
            write_to_file(
                factor,
                os.path.join(model_folder,
                             C.DECODE_REF_NAME.format(factor=factor_idx)))

        self.inputs_sentences = list(
            zip(*self.inputs_sentences))  # type: List[List[str]]

        scorer = inference.CandidateScorer(
            length_penalty_alpha=length_penalty_alpha,
            length_penalty_beta=length_penalty_beta,
            brevity_penalty_weight=0.0,
            prefix='scorer_')

        # TODO: possibly support decoding on multiple GPUs
        self.translator = inference.Translator(
            batch_size=self.batch_size,
            context=context,
            ensemble_mode=self.ensemble_mode,
            scorer=scorer,
            beam_search_stop='all',
            nbest_size=self.nbest_size,
            models=[self.model],
            source_vocabs=source_vocabs,
            target_vocabs=target_vocabs,
            restrict_lexicon=None,
            hybridize=hybridize)

        logger.info(
            "Created CheckpointDecoder(max_input_len=%d, beam_size=%d, num_sentences=%d)",
            max_input_len if max_input_len is not None else -1, beam_size,
            len(self.targets_sentences[0]))
Exemple #16
0
from contextlib import ExitStack, contextmanager

@contextmanager
def nested(*contexts):
    """ Reimplementation of contextlib.nested for Python 3. """
    with ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts
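
Assuming the decorated definition above, nested enters an arbitrary number of context managers in a single with statement, much like the contextlib.nested helper that was removed in Python 3. A small usage sketch (the definition is repeated so the snippet runs on its own; the temporary files are just placeholders):

import tempfile
from contextlib import ExitStack, contextmanager

@contextmanager
def nested(*contexts):
    with ExitStack() as stack:
        for ctx in contexts:
            stack.enter_context(ctx)
        yield contexts

# Both files are entered on the way in and closed on the way out,
# even if the body raises.
with nested(tempfile.TemporaryFile(), tempfile.TemporaryFile()) as (f1, f2):
    f1.write(b"hello")
    f2.write(b"world")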
Exemple #17
0
def run_ec2_experiment(ec2, instance, ccalg, btlbw, rtt, queue_size, region, loss_rate=None, force=False):
    if loss_rate is not None:
        experiment_name = '{}-{}bw-{}rtt-{}q-{}loss-{}'.format(ccalg, btlbw, rtt, queue_size, loss_rate, region)
    else:
        experiment_name = '{}-{}bw-{}rtt-{}q-{}'.format(ccalg, btlbw, rtt, queue_size, region)
    if not force and ccalg_predict.is_completed_experiment(experiment_name):
        return
    else:
        if ccalg_predict.ran_experiment_today(experiment_name):
            return
    logging.info('Creating experiment for instance: {}-{}'.format(region, ccalg))
    instance_rtt = int(float(get_ping_rtt(instance.public_ip_address)))
    logging.info('Got instance RTT: {}'.format(instance_rtt))

    if instance_rtt >= rtt:
        logging.warning('Skipping experiment with instance RTT {} >= {}'.format(
            instance_rtt, rtt))
        return

    # cctestbed.Host(**generate_experiments.HOST_SERVER) raised
    # "TypeError: type object argument after ** must be a mapping, not Host",
    # so the already-constructed Host object is used directly.
    # server = generate_experiments.HOST_SERVER_TEMPLATE
    server = generate_experiments.HOST_SERVER
    client = generate_experiments.HOST_AWS_TEMPLATE
    client['ip_wan'] = instance.public_ip_address
    client['ip_lan'] = instance.private_ip_address
    client['key_filename'] = get_key_pair_path(ec2)

    server_nat_ip_lan = generate_experiments.HOST_CLIENT.ip_lan
    print('server_nat_ip_lan:', server_nat_ip_lan)
    server_nat_ip = generate_experiments.HOST_CLIENT.ip_wan
    print('server_nat_ip_wan:', server_nat_ip)

    client = cctestbed.Host(**client)
    print("AWS Client:", client)
    print("Server:", server)
    # server = cctestbed.Host(**server)

    cloudlab_client = generate_experiments.HOST_CLIENT #CLARIFY
    print("Cloudlab_Client:", cloudlab_client)
    # cloudlab_client['ip_wan'] = server_nat_ip
    # cloudlab_client = cctestbed.Host(**cloudlab_client)

    server_port = 5201
    client_port = 5555

    #print('Connecting dpdk')
    #cctestbed.connect_dpdk(server, client)

    flow = {'ccalg': ccalg,
            'end_time': 60,
            'rtt': rtt - instance_rtt,
            'start_time': 0}
    flows = [cctestbed.Flow(ccalg=flow['ccalg'], start_time=flow['start_time'],
                            end_time=flow['end_time'], rtt=flow['rtt'],
                            server_port=server_port, client_port=client_port,
                            client_log=None, server_log=None, kind='iperf',
                            client=cloudlab_client)]

    exp = cctestbed.Experiment(name=experiment_name,
                               btlbw=btlbw,
                               queue_size=queue_size,
                               flows=flows, server=server, client=client,
                               config_filename='experiments-all-ccalgs-aws.yaml',
                               server_nat_ip=server_nat_ip,
                               loss_rate=loss_rate)

    try:
        # make sure old stuff closed
        exp.cleanup_last_experiment(cleanup_tail=False)
        logging.info('Running experiment: {}'.format(exp.name))
        with ExitStack() as stack:
            # add DNAT rule
            stack.enter_context(ccalg_predict.add_dnat_rule(exp, exp.client.ip_wan))
            # add route to URL
            # stack.enter_context(ccalg_predict.add_route(exp, exp.client.ip_wan,
            #                                             gateway_ip=exp.client.ip_lan)) #CLARIFY
            stack.enter_context(ccalg_predict.add_route(exp, exp.client.ip_wan,
                                                        gateway_ip=server_nat_ip_lan))
            exp._run_tcpdump('server', stack)
            exp._run_tcpdump('client', stack)
            exp._run_tcpprobe(stack)
            stack.enter_context(exp._run_rtt_monitor(program='ping'))
            exp._run_all_flows(stack, bess_config_name='active-middlebox-pmd')
        # compress all log files
        proc = exp._compress_logs_url()
        logging.info('Finished experiment: {}'.format(exp.name))
        return proc
    except Exception as e:
        logging.error('Error occurred while running experiment '+exp.name)
        exp._delete_logs(delete_description=False)
        raise e
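
One detail worth noting in the experiment above: the ExitStack is created once and then handed to helpers (_run_tcpdump, _run_tcpprobe, _run_all_flows), each of which registers its own cleanup on the caller's stack. A minimal sketch of that shape, with entirely made-up helper names:

from contextlib import ExitStack, contextmanager

@contextmanager
def start_monitor(name):
    # stand-in for a tcpdump/tcpprobe style background monitor
    print("start", name)
    try:
        yield
    finally:
        print("stop", name)

def run_monitored(stack, name):
    """Helper that registers its own cleanup on a stack owned by the caller."""
    stack.enter_context(start_monitor(name))

with ExitStack() as stack:
    run_monitored(stack, "server-tcpdump")
    run_monitored(stack, "client-tcpdump")
    print("experiment running")
# both monitors are stopped here, in reverse order of registration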
Exemple #18
0
@contextmanager
def nested(*contexts):
    with ExitStack() as stack:
        for context in contexts:
            stack.enter_context(context)
        yield
Exemple #19
0
def zoom_index_gen(
    mp=None,
    out_dir=None,
    zoom=None,
    geojson=False,
    gpkg=False,
    shapefile=False,
    txt=False,
    vrt=False,
    fieldname="location",
    basepath=None,
    for_gdal=True,
    threading=False,
):
    """
    Generate indexes for given zoom level.

    Parameters
    ----------
    mp : Mapchete object
        process output to be indexed
    out_dir : path
        optionally override process output directory
    zoom : int
        zoom level to be processed
    geojson : bool
        generate GeoJSON index (default: False)
    gpkg : bool
        generate GeoPackage index (default: False)
    shapefile : bool
        generate Shapefile index (default: False)
    txt : bool
        generate tile path list textfile (default: False)
    vrt : bool
        GDAL-style VRT file (default: False)
    fieldname : str
        field name which contains paths of tiles (default: "location")
    basepath : str
        if set, use custom base path instead of output path
    for_gdal : bool
        use GDAL compatible remote paths, i.e. add "/vsicurl/" before path
        (default: True)
    threading : bool
        accepted by the signature but not referenced in this function
        (default: False)
    """
    for zoom in get_zoom_levels(process_zoom_levels=zoom):
        with ExitStack() as es:
            # get index writers for all enabled formats
            index_writers = []
            if geojson:
                index_writers.append(
                    es.enter_context(
                        VectorFileWriter(driver="GeoJSON",
                                         out_path=_index_file_path(
                                             out_dir, zoom, "geojson"),
                                         crs=mp.config.output_pyramid.crs,
                                         fieldname=fieldname)))
            if gpkg:
                index_writers.append(
                    es.enter_context(
                        VectorFileWriter(driver="GPKG",
                                         out_path=_index_file_path(
                                             out_dir, zoom, "gpkg"),
                                         crs=mp.config.output_pyramid.crs,
                                         fieldname=fieldname)))
            if shapefile:
                index_writers.append(
                    es.enter_context(
                        VectorFileWriter(driver="ESRI Shapefile",
                                         out_path=_index_file_path(
                                             out_dir, zoom, "shp"),
                                         crs=mp.config.output_pyramid.crs,
                                         fieldname=fieldname)))
            if txt:
                index_writers.append(
                    es.enter_context(
                        TextFileWriter(
                            out_path=_index_file_path(out_dir, zoom, "txt"))))
            if vrt:
                index_writers.append(
                    es.enter_context(
                        VRTFileWriter(out_path=_index_file_path(
                            out_dir, zoom, "vrt"),
                                      output=mp.config.output,
                                      out_pyramid=mp.config.output_pyramid)))

            logger.debug("use the following index writers: %s", index_writers)

            def _worker(tile):
                # if there are indexes to write to, check if output exists
                tile_path = _tile_path(
                    orig_path=mp.config.output.get_path(tile),
                    basepath=basepath,
                    for_gdal=for_gdal)
                indexes = [
                    i for i in index_writers
                    if not i.entry_exists(tile=tile, path=tile_path)
                ]
                if indexes:
                    output_exists = mp.config.output.tiles_exist(
                        output_tile=tile)
                else:
                    output_exists = None
                return tile, tile_path, indexes, output_exists

            with concurrent.futures.ThreadPoolExecutor() as executor:
                for task in concurrent.futures.as_completed(
                    (executor.submit(_worker, i)
                     for i in mp.config.output_pyramid.tiles_from_geom(
                         mp.config.area_at_zoom(zoom), zoom))):
                    tile, tile_path, indexes, output_exists = task.result()
                    # only write entries if there are indexes to write to and output
                    # exists
                    if indexes and output_exists:
                        logger.debug("%s exists", tile_path)
                        logger.debug("write to %s indexes", len(indexes))
                        for index in indexes:
                            index.write(tile, tile_path)
                    # yield tile for progress information
                    yield tile
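
The core ExitStack pattern in zoom_index_gen is entering a variable number of writers, one per enabled flag, so that every writer that was actually opened gets closed when the block ends. A stripped-down sketch of that pattern; DummyWriter and write_indexes are placeholders, not part of the mapchete API:

from contextlib import ExitStack

class DummyWriter:
    """Placeholder for VectorFileWriter / TextFileWriter / VRTFileWriter."""
    def __init__(self, name):
        self.name = name
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        print("closed", self.name)
    def write(self, entry):
        print(self.name, "<-", entry)

def write_indexes(entries, geojson=False, txt=False):
    with ExitStack() as es:
        writers = []
        if geojson:
            writers.append(es.enter_context(DummyWriter("geojson")))
        if txt:
            writers.append(es.enter_context(DummyWriter("txt")))
        for entry in entries:
            for writer in writers:
                writer.write(entry)
    # every writer that was opened above has been closed here

write_indexes(["tile-1", "tile-2"], geojson=True, txt=True)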
Exemple #20
0
    def main(self, orings):
        for iseqs in izip(*[iring.read(guarantee=self.guarantee)
                            for iring in self.irings]):
            if self.shutdown_event.is_set():
                break
            for i, iseq in enumerate(iseqs):
                self.sequence_proclogs[i].update(iseq.header)
            oheaders = self._on_sequence(iseqs)
            for ohdr in oheaders:
                if 'time_tag' not in ohdr:
                    ohdr['time_tag'] = self._seq_count
            self._seq_count += 1

            igulp_nframes = [
                self.gulp_nframe or iseq.header['gulp_nframe']
                for iseq in iseqs
            ]
            igulp_overlaps = self._define_input_overlap_nframe(iseqs)
            istride_nframes = igulp_nframes[:]
            igulp_nframes = [
                igulp_nframe + nframe_overlap
                for igulp_nframe, nframe_overlap in zip(
                    igulp_nframes, igulp_overlaps)
            ]

            for iseq, igulp_nframe in zip(iseqs, igulp_nframes):
                if self.buffer_factor is None:
                    src_block = iseq.ring.owner
                    if src_block is not None and self.is_fused_with(src_block):
                        buffer_factor = 1
                    else:
                        buffer_factor = None
                else:
                    buffer_factor = self.buffer_factor
                iseq.resize(gulp_nframe=igulp_nframe,
                            buf_nframe=self.buffer_nframe,
                            buffer_factor=buffer_factor)

            # TODO: Ever need to specify starting offset?
            iframe0s = [0 for _ in igulp_nframes]

            force_skip = False

            with ExitStack() as oseq_stack:
                oseqs, ogulp_overlaps = self.begin_sequences(
                    oseq_stack, orings, oheaders, igulp_nframes,
                    istride_nframes)
                if self.shutdown_event.is_set():
                    break
                prev_time = time.time()
                for ispans in izip(*[
                        iseq.read(igulp_nframe, istride_nframe, iframe0)
                        for (iseq, igulp_nframe, istride_nframe, iframe0) in
                        zip(iseqs, igulp_nframes, istride_nframes, iframe0s)
                ]):
                    if self.shutdown_event.is_set():
                        return

                    if any([ispan.nframe_skipped for ispan in ispans]):
                        # There were skipped (overwritten) frames
                        with ExitStack() as ospan_stack:
                            iskip_slices = [
                                slice(iframe0, iframe0 + ispan.nframe_skipped,
                                      istride_nframe)
                                for iframe0, istride_nframe, ispan in zip(
                                    iframe0s, istride_nframes, ispans)
                            ]
                            iskip_nframes = [
                                ispan.nframe_skipped for ispan in ispans
                            ]
                            # ***TODO: Need to loop over multiple ospans here,
                            #            because iskip_nframes can be
                            #            arbitrarily large!
                            ospans = self.reserve_spans(
                                ospan_stack, oseqs, iskip_nframes)
                            ostrides_actual = self._on_skip(
                                iskip_slices, ospans)
                            device.stream_synchronize()
                            self.commit_spans(ospans, ostrides_actual,
                                              ogulp_overlaps)

                    if all([ispan.nframe == 0 for ispan in ispans]):
                        # No data to see here, move right along
                        continue

                    cur_time = time.time()
                    acquire_time = cur_time - prev_time
                    prev_time = cur_time

                    with ExitStack() as ospan_stack:
                        igulp_nframes = [ispan.nframe for ispan in ispans]
                        ospans = self.reserve_spans(ospan_stack, oseqs,
                                                    igulp_nframes)
                        cur_time = time.time()
                        reserve_time = cur_time - prev_time
                        prev_time = cur_time

                        if not force_skip:
                            # *TODO: See if can fuse together multiple on_data calls here before
                            #          calling stream_synchronize().
                            #        Consider passing .data instead of rings here
                            ostrides_actual = self._on_data(ispans, ospans)
                            device.stream_synchronize()

                        any_frames_overwritten = any(
                            [ispan.nframe_overwritten for ispan in ispans])
                        if force_skip or any_frames_overwritten:
                            # Note: To allow interrupted pipelines to catch up,
                            #         we force-skip an additional gulp whenever
                            #         a span is overwritten during on_data.
                            force_skip = any_frames_overwritten
                            iskip_slices = [
                                slice(
                                    ispan.frame_offset, ispan.frame_offset +
                                    ispan.nframe_overwritten, istride_nframe)
                                for ispan, istride_nframe in zip(
                                    ispans, istride_nframes)
                            ]
                            ostrides_actual = self._on_skip(
                                iskip_slices, ospans)
                            device.stream_synchronize()

                        self.commit_spans(ospans, ostrides_actual,
                                          ogulp_overlaps)
                    cur_time = time.time()
                    process_time = cur_time - prev_time
                    prev_time = cur_time
                    self.perf_proclog.update({
                        'acquire_time': acquire_time,
                        'reserve_time': reserve_time,
                        'process_time': process_time
                    })
            # **TODO: This will not be called if an exception is raised
            #           Need to call it from a context manager somehow
            self._on_sequence_end(iseqs)
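
The closing TODO notes that _on_sequence_end is skipped if an exception escapes the loop. One way to guarantee the call from a context manager, as the TODO suggests, is ExitStack.callback, which runs its function on normal exit and when an exception propagates. A hedged sketch of the mechanism (not a patch to the block above; process and on_sequence_end are made-up names):

from contextlib import ExitStack

def process(items, on_sequence_end):
    with ExitStack() as stack:
        # callback() fires on normal exit *and* when an exception propagates
        stack.callback(on_sequence_end, items)
        for item in items:
            if item == "bad":
                raise RuntimeError("boom")
            print("processed", item)

try:
    process(["a", "bad"], lambda seqs: print("sequence end for", seqs))
except RuntimeError:
    pass  # on_sequence_end was still invoked by the ExitStack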
Exemple #21
0
def dedup_tracked1(ds, comm1):
    size = comm1.size
    ds.tt.update(comm1=comm1, size=size)
    by_mh = defaultdict(list)
    for inode in comm1.inodes:
        # XXX Need to cope with deleted inodes.
        # We cannot find them in the search-new pass, not without doing
        # some tracking of directory modifications to poke updated
        # directories to find removed elements.

        # rehash everytime for now
        # I don't know enough about how inode transaction numbers are
        # updated (as opposed to extent updates) to be able to actually
        # cache the result
        with ds.open_by_inode(inode) as rfile:
            if rfile is None:
                continue
            try:
                by_mh[mini_hash_from_file(inode, rfile)].append(inode)
            except IOError as e:
                if e.errno == errno.EIO:
                    ds.tt.notify('%r has IO errors, skipping' % inode)
                    continue
                raise
            ds.tt.update(mhash=None)

    for inodes in by_mh.values():
        inode_count = len(inodes)
        if inode_count < 2:
            continue
        fies = set()
        for inode in inodes:
            with ds.open_by_inode(inode) as rfile:
                if rfile is None:
                    continue
                fies.add(fiemap_hash_from_file(rfile))

        if len(fies) < 2:
            continue

        files = []
        fds = []
        # For description only
        fd_names = {}
        fd_inodes = {}
        by_hash = defaultdict(list)

        # XXX I have no justification for doubling inode_count
        ofile_req = 2 * inode_count + ds.ofile_reserved
        if ofile_req > ds.ofile_soft:
            if ofile_req <= ds.ofile_hard:
                resource.setrlimit(resource.RLIMIT_OFILE,
                                   (ofile_req, ds.ofile_hard))
                ds.ofile_soft = ofile_req
            else:
                ds.tt.notify(
                    'Too many duplicates (%d at size %d), '
                    'would bring us over the open files limit (%d, %d).' %
                    (inode_count, size, ds.ofile_soft, ds.ofile_hard))
                for inode in inodes:
                    if inode.has_updates:
                        ds.skip(inode)
                continue

        for inode in inodes:
            # Open everything rw, we can't pick one for the source side
            # yet because the crypto hash might eliminate it.
            # We may also want to defragment the source.
            try:
                path = inode.vol.live.lookup_one_path(inode)
            except IOError as e:
                if e.errno == errno.ENOENT:
                    ds.sess.delete(inode)
                    continue
                raise
            try:
                afile = fopenat_rw(inode.vol.live.fd, path)
            except IOError as e:
                if e.errno == errno.ETXTBSY:
                    # The file contains the image of a running process,
                    # we can't open it in write mode.
                    ds.tt.notify('File %r is busy, skipping' % path)
                elif e.errno == errno.EACCES:
                    # Could be SELinux or immutability
                    ds.tt.notify('Access denied on %r, skipping' % path)
                elif e.errno == errno.ENOENT:
                    # The file was moved or unlinked by a racing process
                    ds.tt.notify('File %r may have moved, skipping' % path)
                else:
                    raise
                ds.skip(inode)
                continue

            # It's not completely guaranteed we have the right inode,
            # there may still be race conditions at this point.
            # Gets re-checked below (tell and fstat).
            fd = afile.fileno()
            fd_inodes[fd] = inode
            fd_names[fd] = path
            files.append(afile)
            fds.append(fd)

        with ExitStack() as stack:
            for afile in files:
                stack.enter_context(closing(afile))
            # Enter this context last
            immutability = stack.enter_context(ImmutableFDs(fds))

            # With a false positive, some kind of cmp pass that compares
            # all files at once might be more efficient that hashing.
            for afile in files:
                fd = afile.fileno()
                inode = fd_inodes[fd]
                if fd in immutability.fds_in_write_use:
                    ds.tt.notify('File %r is in use, skipping' % fd_names[fd])
                    ds.skip(inode)
                    continue
                hasher = hashlib.sha1()
                try:
                    for buf in iter(lambda: afile.read(BUFSIZE), b''):
                        hasher.update(buf)
                except OSError as e:
                    if e.errno == errno.EIO:
                        continue
                    raise

                # Gets rid of a race condition
                st = os.fstat(fd)
                if st.st_ino != inode.ino:
                    ds.skip(inode)
                    continue
                if st.st_dev != inode.vol.live.st_dev:
                    ds.skip(inode)
                    continue

                size1 = afile.tell()
                if size1 != size:
                    if size1 < inode.vol.size_cutoff:
                        # if we didn't delete this inode, it would cause
                        # spurious comm groups in all future invocations.
                        ds.sess.delete(inode)
                    else:
                        ds.skip(inode)
                    continue

                by_hash[hasher.digest()].append(afile)
                ds.tt.update(fhash=None)

            for fileset in by_hash.values():
                dedup_fileset(ds, fileset, fd_names, fd_inodes, size)
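
The "Enter this context last" comment above works because ExitStack unwinds in LIFO order: ImmutableFDs is entered after all the closing(afile) contexts, so its cleanup runs first on the way out, while the file descriptors it manipulates are still open, and only then are the files themselves closed. A tiny demonstration of the ordering:

from contextlib import ExitStack, contextmanager

@contextmanager
def step(name):
    print("enter", name)
    try:
        yield
    finally:
        print("exit", name)

with ExitStack() as stack:
    stack.enter_context(step("close-file-A"))
    stack.enter_context(step("close-file-B"))
    stack.enter_context(step("immutability"))  # entered last...
# ...so "exit immutability" prints first, then B, then A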
Exemple #22
0
def adversarial_discriminator_relation_triple(net, layers, lamda_w, scope='adversary',
                                              neighbor_num=1, leaky=False, reuse=False,
                                              dropout_keep=1.0):
    flip_gradient = FlipGradientBuilder()
    source_ft_gan, source_ft_gan_differ1, source_ft_gan_differ2 = net

    source_ft_gan = flip_gradient(source_ft_gan, lamda_w)
    source_ft_gan_differ1 = flip_gradient(source_ft_gan_differ1, lamda_w)
    source_ft_gan_differ2 = flip_gradient(source_ft_gan_differ2, lamda_w)

    def LeakyReLU(x, alpha=0.2):
        return tf.maximum(alpha*x, x)
    if leaky:
        activation_fn = LeakyReLU
    else:
        activation_fn = tf.nn.relu


    with tf.variable_scope(scope) as scope:
        if reuse:
            scope.reuse_variables()
        with ExitStack() as stack:
            stack.enter_context(tf.variable_scope(scope))
            stack.enter_context(
                slim.arg_scope(
                    [slim.fully_connected],
                    activation_fn=activation_fn,
                    weights_regularizer=slim.l2_regularizer(2.5e-5)))

            layer_s_s_differ1 = slim.fully_connected(tf.concat([source_ft_gan,source_ft_gan_differ1],axis=1),
                                                     num_outputs=int(layers[0]/3),
                                              activation_fn=activation_fn,
                                              scope="adver_layer1_1")
            layer_s_s_differ1 = tf.nn.dropout(layer_s_s_differ1, keep_prob=0.9)
            layer_s_s_differ1 = slim.fully_connected(layer_s_s_differ1,
                                                     num_outputs=int(layers[0] / 3),
                                                     activation_fn=activation_fn,
                                                     scope="adver_layer1_2")
            layer_s_s_differ1 = tf.nn.dropout(layer_s_s_differ1, keep_prob=0.9)

            layer_s_s_differ2 = slim.fully_connected(tf.concat([source_ft_gan, source_ft_gan_differ2], axis=1),
                                                     num_outputs=int(layers[0]/3 ),
                                                     activation_fn=activation_fn,
                                                     scope="adver_layer1_12")
            layer_s_s_differ2 = tf.nn.dropout(layer_s_s_differ2, keep_prob=0.9)
            layer_s_s_differ2 = slim.fully_connected(layer_s_s_differ2,
                                                     num_outputs=int(layers[0] / 3),
                                                     activation_fn=activation_fn,
                                                     scope="adver_layer1_22")
            layer_s_s_differ2 = tf.nn.dropout(layer_s_s_differ2, keep_prob=0.9)

            layer_s_s_differ1_differ2 = slim.fully_connected(tf.concat([source_ft_gan_differ2, source_ft_gan_differ1], axis=1),
                                                     num_outputs=int(layers[0]/3 ),
                                                     activation_fn=activation_fn,
                                                     scope="adver_layer1_13")
            layer_s_s_differ1_differ2 = tf.nn.dropout(layer_s_s_differ1_differ2, keep_prob=0.9)
            layer_s_s_differ1_differ2 = slim.fully_connected(layer_s_s_differ1_differ2,
                num_outputs=int(layers[0] / 3),
                activation_fn=activation_fn,
                scope="adver_layer1_23")
            layer_s_s_differ1_differ2 = tf.nn.dropout(layer_s_s_differ1_differ2, keep_prob=0.9)

            layer_s_s_final = slim.fully_connected(tf.concat([layer_s_s_differ1, layer_s_s_differ2,layer_s_s_differ1_differ2],1),
                                                     num_outputs=int(layers[1]),
                                                     activation_fn=activation_fn,
                                                     scope="adver_layer2")
            net = tf.nn.dropout(layer_s_s_final, keep_prob=0.8)
            net_1 = slim.fully_connected(net, 1, activation_fn=None, scope="adver_layer3")
            net_2 = slim.fully_connected(net, 2, activation_fn=None, scope="adver_layer4")
    return net_1, net_2
Exemple #23
0
def test_show(unittest, builtin_pkg):
    from dtale.app import show, get_instance, instances
    import dtale.views as views
    import dtale.global_state as global_state

    class MockDtaleFlask(Flask):
        def __init__(self,
                     import_name,
                     reaper_on=True,
                     url=None,
                     app_root=None,
                     *args,
                     **kwargs):
            kwargs.pop("instance_relative_config", None)
            kwargs.pop("static_url_path", None)
            super(MockDtaleFlask, self).__init__(import_name, *args, **kwargs)

        def run(self, *args, **kwargs):
            pass

    instances()
    test_data = pd.DataFrame([dict(a=1, b=2)])
    with ExitStack() as stack:
        mock_run = stack.enter_context(
            mock.patch("dtale.app.DtaleFlask.run", mock.Mock()))
        mock_find_free_port = stack.enter_context(
            mock.patch("dtale.app.find_free_port",
                       mock.Mock(return_value=9999)))
        stack.enter_context(
            mock.patch("socket.gethostname",
                       mock.Mock(return_value="localhost")))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        mock_requests = stack.enter_context(
            mock.patch("requests.get", mock.Mock()))
        instance = show(data=test_data,
                        subprocess=False,
                        name="foo",
                        ignore_duplicate=True)
        print(instance.main_url())
        assert "http://localhost:9999" == instance._url
        assert "http://localhost:9999/dtale/main/foo" == instance.main_url()
        mock_run.assert_called_once()
        mock_find_free_port.assert_called_once()

        pdt.assert_frame_equal(instance.data, test_data)
        tmp = test_data.copy()
        tmp["biz"] = 2.5
        instance.data = tmp
        unittest.assertEqual(
            global_state.DTYPES[instance._data_id],
            views.build_dtypes_state(tmp),
            "should update app data/dtypes",
        )

        instance2 = get_instance(instance._data_id)
        assert instance2._url == instance._url
        instance2 = get_instance("foo")
        assert instance2._url == instance._url
        pdt.assert_frame_equal(instance2.data, tmp)

        instances()

        assert get_instance(20) is None  # should return None for invalid data ids

        instance.kill()
        mock_requests.assert_called_once()
        assert mock_requests.call_args[0][0] == "http://localhost:9999/shutdown"
        assert global_state.METADATA["1"]["name"] == "foo"

        instance3 = show(data=test_data,
                         subprocess=False,
                         name="It's Here",
                         ignore_duplicate=True)
        assert (instance3.main_url() ==
                "http://localhost:9999/dtale/main/its_here")
        pdt.assert_frame_equal(instance3.data, test_data)

    with ExitStack() as stack:
        mock_run = stack.enter_context(
            mock.patch("dtale.app.DtaleFlask.run", mock.Mock()))
        mock_find_free_port = stack.enter_context(
            mock.patch("dtale.app.find_free_port",
                       mock.Mock(return_value=9999)))
        stack.enter_context(
            mock.patch("socket.gethostname",
                       mock.Mock(return_value="localhost")))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        mock_data_loader = mock.Mock(return_value=test_data)
        instance = show(
            data_loader=mock_data_loader,
            subprocess=False,
            port=9999,
            force=True,
            debug=True,
            ignore_duplicate=True,
        )
        assert "http://localhost:9999" == instance._url
        mock_run.assert_called_once()
        mock_find_free_port.assert_not_called()
        mock_data_loader.assert_called_once()
        _, kwargs = mock_run.call_args

        assert "9999" in instance._url

    with ExitStack() as stack:
        mock_run = stack.enter_context(
            mock.patch("dtale.app.DtaleFlask.run", mock.Mock()))
        stack.enter_context(
            mock.patch("dtale.app.find_free_port",
                       mock.Mock(return_value=9999)))
        stack.enter_context(
            mock.patch("socket.gethostname",
                       mock.Mock(return_value="localhost")))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=True)))
        mock_data_loader = mock.Mock(return_value=test_data)
        mock_webbrowser = stack.enter_context(mock.patch("webbrowser.get"))
        instance = show(
            data_loader=mock_data_loader,
            subprocess=False,
            port=9999,
            open_browser=True,
            ignore_duplicate=True,
        )
        mock_run.assert_not_called()
        webbrowser_instance = mock_webbrowser.return_value
        assert ("http://localhost:9999/dtale/main/4" ==
                webbrowser_instance.open.call_args[0][0])
        instance.open_browser()
        assert ("http://localhost:9999/dtale/main/4" ==
                webbrowser_instance.open.mock_calls[1][1][0])

    # RangeIndex test
    test_data = pd.DataFrame([1, 2, 3])
    with ExitStack() as stack:
        stack.enter_context(mock.patch("dtale.app.DtaleFlask", MockDtaleFlask))
        stack.enter_context(
            mock.patch("dtale.app.find_free_port",
                       mock.Mock(return_value=9999)))
        stack.enter_context(
            mock.patch("socket.gethostname",
                       mock.Mock(return_value="localhost")))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        stack.enter_context(mock.patch("dtale.app.logger", mock.Mock()))
        instance = show(data=test_data,
                        subprocess=False,
                        name="foo",
                        ignore_duplicate=True)
        assert np.array_equal(instance.data["0"].values, test_data[0].values)

    with ExitStack() as stack:
        stack.enter_context(mock.patch("dtale.app.DtaleFlask", MockDtaleFlask))
        stack.enter_context(
            mock.patch("dtale.app.find_free_port",
                       mock.Mock(return_value=9999)))
        stack.enter_context(
            mock.patch("socket.gethostname",
                       mock.Mock(return_value="localhost")))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        stack.enter_context(mock.patch("dtale.app.logger", mock.Mock()))
        stack.enter_context(
            mock.patch("dtale.views.in_ipython_frontend",
                       mock.Mock(return_value=False)))

        get_calls = {"ct": 0}
        getter = namedtuple("get", "ok")

        def mock_requests_get(url, verify=True):
            if url.endswith("/health"):
                is_ok = get_calls["ct"] > 0
                get_calls["ct"] += 1
                return getter(is_ok)
            return getter(True)

        stack.enter_context(mock.patch("requests.get", mock_requests_get))
        mock_display = stack.enter_context(
            mock.patch("IPython.display.display", mock.Mock()))
        mock_iframe = stack.enter_context(
            mock.patch("IPython.display.IFrame", mock.Mock()))
        instance = show(
            data=test_data,
            subprocess=True,
            name="foo",
            notebook=True,
            ignore_duplicate=True,
        )
        mock_display.assert_called_once()
        mock_iframe.assert_called_once()
        assert (mock_iframe.call_args[0][0] ==
                "http://localhost:9999/dtale/iframe/6")

        assert type(instance.__str__()).__name__ == "str"
        assert type(instance.__repr__()).__name__ == "str"

    class MockDtaleFlaskRunTest(Flask):
        def __init__(self,
                     import_name,
                     reaper_on=True,
                     url=None,
                     app_root=None,
                     *args,
                     **kwargs):
            kwargs.pop("instance_relative_config", None)
            kwargs.pop("static_url_path", None)
            super(MockDtaleFlaskRunTest,
                  self).__init__(import_name, *args, **kwargs)

        def run(self, *args, **kwargs):
            assert self.jinja_env.auto_reload
            assert self.config["TEMPLATES_AUTO_RELOAD"]

    with mock.patch("dtale.app.DtaleFlask", MockDtaleFlaskRunTest):
        show(
            data=test_data,
            subprocess=False,
            port=9999,
            debug=True,
            ignore_duplicate=True,
        )

    with mock.patch("dtale.app._thread.start_new_thread",
                    mock.Mock()) as mock_thread:
        show(data=test_data, subprocess=True, ignore_duplicate=True)
        mock_thread.assert_called()

    test_data = pd.DataFrame([dict(a=1, b=2)])

    with ExitStack() as stack:
        mock_build_app = stack.enter_context(
            mock.patch("dtale.app.build_app", mock.Mock()))
        stack.enter_context(
            mock.patch("dtale.app.find_free_port",
                       mock.Mock(return_value=9999)))
        stack.enter_context(
            mock.patch("socket.gethostname",
                       mock.Mock(return_value="localhost")))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        stack.enter_context(mock.patch("requests.get", mock.Mock()))
        show(data=test_data,
             subprocess=False,
             name="foo",
             ignore_duplicate=True)

        _, kwargs = mock_build_app.call_args
        unittest.assertEqual(
            {
                "app_root": None,
                "host": "localhost",
                "reaper_on": True
            },
            kwargs,
            "build_app should be called with defaults",
        )

    # test adding duplicate column
    with ExitStack() as stack:
        stack.enter_context(mock.patch("dtale.app.DtaleFlask", MockDtaleFlask))
        stack.enter_context(
            mock.patch("dtale.app.find_free_port",
                       mock.Mock(return_value=9999)))
        stack.enter_context(
            mock.patch("socket.gethostname",
                       mock.Mock(return_value="localhost")))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        stack.enter_context(mock.patch("requests.get", mock.Mock()))
        instance = show(
            data=pd.DataFrame([dict(a=1, b=2)]),
            subprocess=False,
            name="foo",
            ignore_duplicate=True,
        )
        with pytest.raises(Exception):
            instance.data = instance.data.rename(columns={"b": "a"})

        curr_instance_ct = len(global_state.DATA)
        show(data=pd.DataFrame([dict(a=1, b=2)]), subprocess=False, name="foo")
        assert curr_instance_ct == len(global_state.DATA)

    # cleanup
    global_state.cleanup()
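
Nearly every block in this test uses the same idiom: an ExitStack entering a batch of mock.patch context managers, so all patches are reverted together at the end of the block without deeply nested with statements. A self-contained sketch of the idiom; the patched targets here are standard-library functions, not dtale internals:

import os
import socket
from contextlib import ExitStack
from unittest import mock

with ExitStack() as stack:
    mock_hostname = stack.enter_context(
        mock.patch("socket.gethostname", mock.Mock(return_value="testhost")))
    stack.enter_context(
        mock.patch("os.getpid", mock.Mock(return_value=1234)))

    assert socket.gethostname() == "testhost"
    assert os.getpid() == 1234
    mock_hostname.assert_called_once()

# outside the block both patches have been reverted
print(socket.gethostname(), os.getpid())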
Exemple #24
0
def daemon(
    *,
    test: bool,
    verbose: bool,
    config: str,
    pidfile: str,
    logfile: str,
    exporter_listen_host: str
):
    """The main program of afancontrol."""

    logging.basicConfig(level=logging.DEBUG if verbose else logging.INFO)

    config_path = Path(config)
    daemon_cli_config = DaemonCLIConfig(
        pidfile=pidfile, logfile=logfile, exporter_listen_host=exporter_listen_host
    )
    parsed_config = parse_config(config_path, daemon_cli_config)

    if parsed_config.daemon.exporter_listen_host:
        metrics: Metrics = PrometheusMetrics(parsed_config.daemon.exporter_listen_host)
    else:
        metrics = NullMetrics()

    manager = Manager(
        arduino_connections=parsed_config.arduino_connections,
        fans=parsed_config.fans,
        readonly_fans=parsed_config.readonly_fans,
        temps=parsed_config.temps,
        mappings=parsed_config.mappings,
        report=Report(report_command=parsed_config.report_cmd),
        triggers_config=parsed_config.triggers,
        metrics=metrics,
    )

    pidfile_instance: Optional[PidFile] = None
    if parsed_config.daemon.pidfile is not None:
        pidfile_instance = PidFile(parsed_config.daemon.pidfile)

    if test:
        print("Config file '%s' is good" % config_path)
        return

    if parsed_config.daemon.logfile:
        # Logging to file should not be configured when running in
        # the config test mode.
        file_handler = logging.FileHandler(parsed_config.daemon.logfile)
        file_handler.setFormatter(
            logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s")
        )
        logging.getLogger().addHandler(file_handler)

    signals = Signals()
    signal.signal(signal.SIGTERM, signals.sigterm)
    signal.signal(signal.SIGQUIT, signals.sigterm)
    signal.signal(signal.SIGINT, signals.sigterm)
    signal.signal(signal.SIGHUP, signals.sigterm)

    with ExitStack() as stack:
        if pidfile_instance is not None:
            stack.enter_context(pidfile_instance)
            pidfile_instance.save_pid(os.getpid())

        stack.enter_context(manager)

        # Make a first tick. If something is wrong, (e.g. bad fan/temp
        # file paths), an exception would be raised here.
        manager.tick()

        while not signals.wait_for_term_queued(parsed_config.daemon.interval):
            manager.tick()
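
The daemon enters the pidfile (when one is configured) before the Manager, so on shutdown the manager is torn down first and the pidfile is released last. A rough sketch of that optional-resource ordering; the pidfile/manager context managers and the path are illustrative, not the afancontrol API:

from contextlib import ExitStack, contextmanager, nullcontext

@contextmanager
def pidfile(path):
    print("write pid to", path)
    try:
        yield
    finally:
        print("remove", path)

@contextmanager
def manager():
    print("manager started")
    try:
        yield
    finally:
        print("manager stopped")

def run(pidfile_path=None):
    with ExitStack() as stack:
        # optional resource: a real pidfile or a no-op placeholder
        stack.enter_context(pidfile(pidfile_path) if pidfile_path else nullcontext())
        stack.enter_context(manager())
        print("main loop would run here")
    # on exit: manager stopped first, pidfile removed last (LIFO)

run("example.pid")
run()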
Exemple #25
0
def test_show_jupyter_server_proxy(unittest):
    from dtale.app import show, get_instance, instances
    import dtale.views as views
    import dtale.global_state as global_state

    test_data = pd.DataFrame([dict(a=1, b=2)])
    with ExitStack() as stack:
        stack.enter_context(mock.patch("dtale.app.JUPYTER_SERVER_PROXY", True))
        mock_run = stack.enter_context(
            mock.patch("dtale.app.DtaleFlask.run", mock.Mock()))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        stack.enter_context(mock.patch("dtale.app.ACTIVE_PORT", 40000))
        stack.enter_context(mock.patch("dtale.app.ACTIVE_HOST", "localhost"))
        mock_requests = stack.enter_context(
            mock.patch("requests.get", mock.Mock()))
        instance = show(data=test_data,
                        subprocess=False,
                        name="foo",
                        ignore_duplicate=True)
        assert "/user/{}/proxy/40000".format(
            getpass.getuser()) == instance._url
        mock_run.assert_called_once()

        pdt.assert_frame_equal(instance.data, test_data)
        tmp = test_data.copy()
        tmp["biz"] = 2.5
        instance.data = tmp
        unittest.assertEqual(
            global_state.DTYPES[instance._data_id],
            views.build_dtypes_state(tmp),
            "should update app data/dtypes",
        )

        instance2 = get_instance(instance._data_id)
        assert instance2._url == instance._url
        instances()

        assert get_instance(20) is None  # should return None for invalid data ids

        instance.kill()
        mock_requests.assert_called_once()
        assert (mock_requests.call_args[0][0] ==
                "/user/{}/proxy/40000/shutdown".format(getpass.getuser()))
        assert global_state.METADATA["1"]["name"] == "foo"

    with ExitStack() as stack:
        stack.enter_context(mock.patch("dtale.app.JUPYTER_SERVER_PROXY", True))
        mock_run = stack.enter_context(
            mock.patch("dtale.app.DtaleFlask.run", mock.Mock()))
        stack.enter_context(
            mock.patch("dtale.app.is_up", mock.Mock(return_value=False)))
        mock_requests = stack.enter_context(
            mock.patch("requests.get", mock.Mock()))
        instance = show(
            data=test_data,
            subprocess=False,
            ignore_duplicate=True,
            app_root="/custom_root/",
        )
        assert "/custom_root/40000" == instance._url
        mock_run.assert_called_once()

        instance2 = get_instance(instance._data_id)
        # this is a known bug where get_instance will not work if you've specified an `app_root' in show()
        assert not instance2._url == instance._url
        instances()
        instance.kill()
        mock_requests.assert_called_once()
        assert mock_requests.call_args[0][0] == "/custom_root/40000/shutdown"
Exemple #26
0
def gc(
    self,
    all_branches=False,
    cloud=False,
    remote=None,
    with_deps=False,
    all_tags=False,
    all_commits=False,
    all_experiments=False,
    force=False,
    jobs=None,
    repos=None,
    workspace=False,
):

    # require `workspace` to be true to come into effect.
    # assume `workspace` to be enabled if any of `all_tags`, `all_commits`,
    # `all_experiments` or `all_branches` are enabled.
    _raise_error_if_all_disabled(
        workspace=workspace,
        all_tags=all_tags,
        all_commits=all_commits,
        all_branches=all_branches,
        all_experiments=all_experiments,
    )

    from contextlib import ExitStack

    from dvc.repo import Repo

    if not repos:
        repos = []
    all_repos = [Repo(path) for path in repos]

    used_objs: Set["HashFile"] = set()
    with ExitStack() as stack:
        for repo in all_repos:
            stack.enter_context(repo.lock)

        for repo in all_repos + [self]:
            for objs in repo.used_objs(
                    all_branches=all_branches,
                    with_deps=with_deps,
                    all_tags=all_tags,
                    all_commits=all_commits,
                    all_experiments=all_experiments,
                    remote=remote,
                    force=force,
                    jobs=jobs,
            ).values():
                used_objs.update(objs)

    for scheme, odb in self.odb.by_scheme():
        if not odb:
            continue

        removed = odb.gc(
            {obj
             for obj in used_objs if obj.fs.scheme == scheme},
            jobs=jobs,
        )
        if not removed:
            logger.info(f"No unused '{scheme}' cache to remove.")

    if not cloud:
        return

    remote = self.cloud.get_remote(remote, "gc -c")
    removed = remote.gc(
        {obj
         for obj in used_objs if obj.fs.scheme == Schemes.LOCAL},
        jobs=jobs,
    )
    if not removed:
        logger.info("No unused cache to remove from remote.")
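
Here the ExitStack acquires one lock per external repo so that all of them are held while used objects are collected, and all are released together when the block ends. A generic sketch of holding a dynamic number of locks at once; threading.Lock stands in for the repo lock object:

import threading
from contextlib import ExitStack

locks = [threading.Lock() for _ in range(3)]

with ExitStack() as stack:
    for lock in locks:
        stack.enter_context(lock)  # acquire every lock
    assert all(lock.locked() for lock in locks)
    print("all repos locked; safe to collect used objects")

# leaving the block released every lock, in reverse acquisition order
assert not any(lock.locked() for lock in locks)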
Exemple #27
0
 def __init__(self, mgrs):
     self.default, self.stack = L(mgrs), ExitStack()
Exemple #28
0
    def decode_and_evaluate(self,
                            output_name: Optional[str] = None
                            ) -> Dict[str, float]:
        """
        Decodes data set and evaluates given a checkpoint.

        :param output_name: Filename to write translations to. If None, will not write outputs.
        :return: Mapping of metric names to scores.
        """

        # 1. Translate
        trans_wall_time = 0.0
        translations = []  # type: List[List[str]]
        with ExitStack() as exit_stack:
            outputs = [
                exit_stack.enter_context(
                    data_io.smart_open(output_name.format(
                        factor=idx), 'w')) if output_name is not None else None
                for idx in range(self.model.num_target_factors)
            ]

            tic = time.time()
            trans_inputs = []  # type: List[inference.TranslatorInput]
            for i, inputs in enumerate(self.inputs_sentences):
                trans_inputs.append(
                    sockeye.inference.make_input_from_multiple_strings(
                        i, inputs))
            trans_outputs = self.translator.translate(trans_inputs)
            trans_wall_time = time.time() - tic
            for trans_input, trans_output in zip(trans_inputs, trans_outputs):
                output_strings = [trans_output.translation]
                if trans_output.factor_translations is not None and len(
                        outputs) > 1:
                    output_strings += trans_output.factor_translations
                translations.append(output_strings)
                for output_string, output_file in zip(output_strings, outputs):
                    if output_file is not None:
                        print(output_string, file=output_file)
        avg_time = trans_wall_time / len(self.targets_sentences[0])
        translations = list(zip(*translations))

        # 2. Evaluate
        metrics = {
            C.BLEU:
            evaluate.raw_corpus_bleu(hypotheses=translations[0],
                                     references=self.targets_sentences[0],
                                     offset=0.01),
            C.CHRF:
            evaluate.raw_corpus_chrf(hypotheses=translations[0],
                                     references=self.targets_sentences[0]),
            C.ROUGE1:
            evaluate.raw_corpus_rouge1(hypotheses=translations[0],
                                       references=self.targets_sentences[0]),
            C.ROUGE2:
            evaluate.raw_corpus_rouge2(hypotheses=translations[0],
                                       references=self.targets_sentences[0]),
            C.ROUGEL:
            evaluate.raw_corpus_rougel(hypotheses=translations[0],
                                       references=self.targets_sentences[0]),
            C.LENRATIO:
            evaluate.raw_corpus_length_ratio(
                hypotheses=translations[0],
                references=self.targets_sentences[0]),
            C.AVG_TIME:
            avg_time,
            C.DECODING_TIME:
            trans_wall_time
        }

        if len(translations) > 1:  # metrics for other target factors
            for i, _ in enumerate(translations[1:], 1):
                # only BLEU
                metrics.update({
                    'f%d-%s' % (i, C.BLEU):
                    evaluate.raw_corpus_bleu(
                        hypotheses=translations[i],
                        references=self.targets_sentences[i],
                        offset=0.01)
                })
        return metrics
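
The translate step above opens one output file per target factor only when output_name is given, keeping None placeholders otherwise so the later zip still lines up. A small sketch of that optional-outputs pattern with plain files; write_factors, the paths and the factor count are made up:

from contextlib import ExitStack

def write_factors(rows, output_name=None, num_factors=2):
    """rows is a list of per-sentence lists with one string per factor."""
    with ExitStack() as exit_stack:
        outputs = [
            exit_stack.enter_context(open(output_name.format(factor=idx), "w"))
            if output_name is not None else None
            for idx in range(num_factors)
        ]
        for row in rows:
            for text, output_file in zip(row, outputs):
                if output_file is not None:
                    print(text, file=output_file)

# nothing is written without an output_name; with one, factor0.txt / factor1.txt appear
write_factors([["hello", "h e l l o"]], output_name="factor{factor}.txt")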
Exemple #29
0
    def _run_command(self, args, cwd=None, dest_dir=None, tool_name=None):
        if dest_dir is None:
            dest_dir = self.log_dir

        if tool_name is None:
            tool_name = self.tool_name

        os.makedirs(dest_dir, exist_ok=True)

        out_path = os.path.join(dest_dir, f"{tool_name}.out")
        cmdline_path = os.path.join(dest_dir, f"{tool_name}.cmdline")

        is_device_monitoring_enabled = config.is_device_monitoring_enabled()

        cmdline = shlex_join(args)

        with open(cmdline_path, "w") as fd:
            fd.write(cmdline)
            fd.write("\n")

        with ExitStack() as stack:
            print(cmdline)
            print(f"Output will be written to {out_path}")
            print()

            if is_device_monitoring_enabled:
                diskstats_before_path = save_diskstats(dest_dir, "BEFORE")

            save_mpool_info(dest_dir, "BEFORE")

            fd = open(out_path, "w")
            stack.enter_context(fd)

            end_timestamp_ms = None
            start_timestamp_ms = int(time.time() * 1000)

            if self.start_timestamp_ms is None:
                self.start_timestamp_ms = start_timestamp_ms

            vmstat_proc, vmstat_fd = spawn_vmstat(dest_dir)
            stack.enter_context(vmstat_fd)
            stack.enter_context(vmstat_proc)

            if is_device_monitoring_enabled:
                iostat_proc, iostat_fd = spawn_iostat(dest_dir)
                stack.enter_context(iostat_fd)
                stack.enter_context(iostat_proc)

            #
            # now launching the main process
            #
            proc = subprocess.Popen(args,
                                    cwd=cwd,
                                    stdout=fd,
                                    stderr=subprocess.STDOUT)
            stack.enter_context(proc)
            pid = proc.pid

            pidstat_proc, pidstat_fd = spawn_pidstat(dest_dir, pid)
            stack.enter_context(pidstat_fd)
            stack.enter_context(pidstat_proc)

            returncode = proc.wait()

            vmstat_proc.kill()

            if is_device_monitoring_enabled:
                iostat_proc.kill()

            if returncode != 0:
                raise Exception(
                    f"Command {args} failed with exit status {proc.returncode}"
                )

        end_timestamp_ms = int(time.time() * 1000)
        self.end_timestamp_ms = end_timestamp_ms

        save_mpool_info(dest_dir, "AFTER")

        if is_device_monitoring_enabled:
            diskstats_after_path = save_diskstats(dest_dir, "AFTER")

            diskstats_report = generate_diskstats_report(
                diskstats_before_path, diskstats_after_path)

            if self.start_diskstats_before_path is None:
                self.start_diskstats_before_path = diskstats_before_path

            full_run_diskstats_report = generate_diskstats_report(
                self.start_diskstats_before_path, diskstats_after_path)

            self.report["diskstats"] = full_run_diskstats_report["delta"]
        else:
            diskstats_report = None

        completed_info = CompletedCommand(
            start_timestamp_ms,
            end_timestamp_ms,
            out_path=out_path,
            diskstats=diskstats_report["delta"] if diskstats_report else None,
        )

        return completed_info
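
The stack above collects open log files and the subprocess handles of the monitoring tools, so every descriptor and child process gets cleaned up even if the benchmarked command fails. Popen objects have been usable as context managers since Python 3.2, which is what makes stack.enter_context(proc) work. A minimal sketch of the same shape; the commands and file name are placeholders:

import subprocess
import sys
from contextlib import ExitStack

with ExitStack() as stack:
    out_fd = stack.enter_context(open("main.out", "w"))

    # stand-in for a vmstat/iostat style background monitor
    monitor = stack.enter_context(
        subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"]))

    # the main command, with stdout captured to the log file
    proc = stack.enter_context(
        subprocess.Popen([sys.executable, "-c", "print('benchmark running')"],
                         stdout=out_fd, stderr=subprocess.STDOUT))
    returncode = proc.wait()

    monitor.kill()  # stop the monitor once the main command has finished

print("exit status:", returncode)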
Exemple #30
0
    def horovod_train(self, model):
        # call setup after the ddp process has connected
        self.setup('fit')
        if self.is_function_implemented('setup', model):
            model.setup('fit')

        if torch.cuda.is_available() and self.on_gpu:
            # Horovod: pin GPU to local rank
            assert self.root_gpu == hvd.local_rank()
            torch.cuda.set_device(self.root_gpu)
            model.cuda(self.root_gpu)

        # avoid duplicating progress bar
        if hvd.rank() != 0 and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(
            model)

        # Horovod: scale the learning rate by the number of workers to account for
        # increased total batch size
        for optimizer in self.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= hvd.size()

        # Horovod: adjust base LR used by schedulers to match scaled optimizer initial LR
        for scheduler in self.lr_schedulers:
            scheduler = scheduler['scheduler']
            if isinstance(scheduler, _LRScheduler):
                scheduler.base_lrs = [
                    lr * hvd.size() for lr in scheduler.base_lrs
                ]

        if self.use_amp:
            model, optimizers = model.configure_apex(amp, model,
                                                     self.optimizers,
                                                     self.amp_level)
            self.optimizers = optimizers
            self.reinit_scheduler_properties(self.optimizers,
                                             self.lr_schedulers)

        # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
        hvd.broadcast_parameters(model.state_dict(), root_rank=0)
        for optimizer in self.optimizers:
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        def filter_named_parameters(model, optimizer):
            opt_params = set([
                p for group in optimizer.param_groups
                for p in group.get('params', [])
            ])
            return [(name, p) for name, p in model.named_parameters()
                    if p in opt_params]

        # Horovod: wrap optimizers to perform gradient aggregation via allreduce
        self.optimizers = [
            hvd.DistributedOptimizer(optimizer,
                                     named_parameters=filter_named_parameters(
                                         model, optimizer))
            for optimizer in self.optimizers
        ]

        # Update logger rank info from Horovod to avoid race conditions from  different ranks
        # creating directories / writing files in the same locations.
        self.global_rank = hvd.rank()
        rank_zero_only.rank = self.global_rank

        with ExitStack() as stack:
            for optimizer in self.optimizers:
                # Synchronization will be performed explicitly following backward()
                stack.enter_context(optimizer.skip_synchronize())

            result = self.run_pretrain_routine(model)

        # Make sure all workers have finished training before returning to the user
        hvd.join()
        return result