Code Example #1
def run(DD_config, D_server_desc):

    if D_server_desc['hostname'] in ["szkmbp"]:
        D_server_desc['hostname'] = "localhost"

    rsconn = get_rsconn_with_timeout(
        D_server_desc,
        timeout=DD_config['database']['connection_setup_timeout'],
        wait_for_parameters_to_be_present=False)

    L_measurements = DD_config['database']['L_measurements']

    master_minibatch_size = DD_config['database']['master_minibatch_size']
    serialized_parameters_format = DD_config['database']['serialized_parameters_format']
    Ntrain = DD_config['model']['Ntrain']
    # Default behavior is to have no staleness, and perform ISGD from the moment that we
    # get values for all the importance weights. Until then, we do USGD.
    staleness_threshold_seconds = DD_config['database']['staleness_threshold_seconds']
    staleness_threshold_num_minibatches_master_processed = DD_config['database']['staleness_threshold_num_minibatches_master_processed']
    importance_weight_additive_constant = DD_config['database']['importance_weight_additive_constant']

    want_master_to_do_USGD_when_ISGD_is_not_possible = DD_config['database'].get('want_master_to_do_USGD_when_ISGD_is_not_possible', True)
    master_usable_importance_weights_threshold_to_ISGD = DD_config['database'].get('master_usable_importance_weights_threshold_to_ISGD', 1.0)
    master_routine = DD_config['model']['master_routine']
    if master_routine[0] != "sync_params":
        print "Error. Your master_routine should always start with 'sync_params'."
        print master_routine
        quit()
    turn_off_importance_sampling = DD_config["model"].get("turn_off_importance_sampling", False)
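    # A minimal sketch of the configuration entries read above (hypothetical
    # values, for illustration only; the real DD_config is built elsewhere):
    #
    #   DD_config = {
    #       'database': {'connection_setup_timeout': 60,
    #                    'master_minibatch_size': 128,
    #                    'serialized_parameters_format': 'opaque_string',
    #                    'staleness_threshold_seconds': None,
    #                    'staleness_threshold_num_minibatches_master_processed': None,
    #                    'importance_weight_additive_constant': 0.0,
    #                    'logging_folder': '/tmp/logs',
    #                    ...},
    #       'model': {'Ntrain': 50000,
    #                 'master_routine': ['sync_params', 'refresh_importance_weights', 'process_minibatch'],
    #                 'turn_off_importance_sampling': False},
    #   }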



    # set up python logging system for logging_folder
    setup_python_logger(folder=DD_config["database"]["logging_folder"])
    logging.info(pprint.pformat(DD_config))

    remote_redis_logger = integration_distributed_training.server.logger.RedisLogger(rsconn, queue_prefix_identifier="service_master")
    model_api = ModelAPI(DD_config['model'])
    # This `record_machine_info` has to be called after the component that
    # makes use of theano if we hope to properly record the theano.config.
    integration_distributed_training.server.logger.record_machine_info(remote_redis_logger)

    # It's very important to determine if we're resuming from a previous run,
    # in which case we really want to load the parameters to resume training.
    if not check_if_parameters_are_present(rsconn):
        ### resuming_from_previous_run = False ###
        msg = "Starting a new run."
        remote_redis_logger.log('event', msg)
        logging.info(msg)
    else:
        ### resuming_from_previous_run = True ###

        msg = "Resuming from previous run."
        remote_redis_logger.log('event', msg)
        logging.info(msg)

        # This whole section is taken almost exactly from the service_worker.
        tic = time.time()
        current_parameters_str = rsconn.get("parameters:current")
        toc = time.time()
        remote_redis_logger.log('timing_profiler', {'sync_params_from_database' : (toc-tic)})

        if current_parameters_str is None or len(current_parameters_str) == 0:
            print "Error. No parameters found in the server."
            quit()

        if serialized_parameters_format == "opaque_string":
            tic = time.time()
            model_api.set_serialized_parameters(current_parameters_str)
            toc = time.time()
            remote_redis_logger.log('timing_profiler', {'model_api.set_serialized_parameters' : (toc-tic)})
            logging.info("The master has received initial parameters. This took %f seconds." % (toc - tic,))

        elif serialized_parameters_format == "ndarray_float32_tostring":
            tic = time.time()
            # Deserialize the raw bytes back into a float32 ndarray.
            current_parameters = np.fromstring(current_parameters_str, dtype=np.float32)
            model_api.set_serialized_parameters(current_parameters)
            toc = time.time()
            remote_redis_logger.log('timing_profiler', {'model_api.set_serialized_parameters' : (toc-tic)})
            logging.info("The master has received initial parameters. This took %f seconds." % (toc - tic,))

        else:
            logging.info("Fatal error : invalid serialized_parameters_format : %s." % serialized_parameters_format)
            quit()


    # Run just a simple test to make sure that the importance weights have been
    # set to something. In theory, there should always be valid values in there,
    # so this is just a sanity check.
    segment = "train"
    measurement = "individual_importance_weight"
    #nbr_of_present_importance_weights = rsconn.hlen("H_%s_minibatch_%s" % (segment, measurement))
    #assert 0 < nbr_of_present_importance_weights, "Error. The database should have been set up to have dummy importance weights at least."

    #print "Master found %d importance weights in the database." % nbr_of_present_importance_weights


    # The master splits its time between two tasks.
    #
    # (1) Publish the parameters back to the server,
    #     which triggers a cascade of re-evaluation of
    #     importance weights for every minibatch.

    # (2) Get samples representing training examples
    #     on which you perform training steps, taking into
    #     consideration all the things about the importance weights.
    #
    # Ultimately, the parameters must be shared, but it is
    # wasteful to do it at every training step. We have to find
    # the right balance.
    #
    # Task (1) should also be triggered on the first iteration
    # to initialize the parameters on the server before anything
    # else (that being said, the initial weights for all the batches
    # are 1.0, so things could start with Task (2) since the assistant
    # would start by resampling the indices).

    queue_name = "L_master_train_minibatch_indices_and_info_QUEUE"

    num_minibatches_master_processed_str = rsconn.get("parameters:num_minibatches_master_processed")
    if num_minibatches_master_processed_str is None or len(num_minibatches_master_processed_str) == 0:
        num_minibatches_master_processed = 0.0
        rsconn.set("parameters:num_minibatches_master_processed", num_minibatches_master_processed)
    else:
        num_minibatches_master_processed = float(num_minibatches_master_processed_str)

    print "num_minibatches_master_processed is %f" % num_minibatches_master_processed

    # The main loop runs until the user hits CTRL+C or until
    # the Helios cluster sends the SIGTERM to that process
    # five minutes before the end of training.
    def signal_handler(signal, frame):
        print "SIGTERM received for the first time."
        print "Will break from master main loop."
        print "Will make logger sync to database before terminating."
        print ""
        logging.info("Master received SIGTERM. %f, %s" % (time.time(), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
        signal_handler.remote_redis_logger.log('event', "Master received SIGTERM.")
        signal_handler.remote_redis_logger.close()
        sys.exit(0)
    #
    signal.signal(signal.SIGINT, signal_handler)
    # The handler and the comment above both talk about SIGTERM, so register it too.
    signal.signal(signal.SIGTERM, signal_handler)
    # I'm forced to use weird function properties because python
    # has stupid scoping rules.
    signal_handler.remote_redis_logger = remote_redis_logger
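    # (Aside: since signal_handler only reads the logger, it could also close over
    #  `remote_redis_logger` directly; the function attribute only matters if the
    #  logger object were reassigned later. In Python 3, a `nonlocal` declaration
    #  would cover the reassignment case as well.)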

    # cache those values to use them for more than one computation
    D_importance_weights_and_more = None
    extra_statistics = None

    remote_redis_logger.log('event', "Before entering service_master main loop.")
    while True:

        for next_action in master_routine:
            assert next_action in [ "sync_params", "refresh_importance_weights", "process_minibatch", # the normal ones
                                    "wait_for_workers_to_update_all_the_importance_weights"]          # the special ones

            if next_action == "sync_params":

                remote_redis_logger.log('event', "sync_params")
                tic = time.time()
                if serialized_parameters_format == "opaque_string":
                    current_parameters_str = model_api.get_serialized_parameters()
                elif serialized_parameters_format == "ndarray_float32_tostring":
                    current_parameters_str = model_api.get_serialized_parameters().tostring(order='C')
                else:
                    logging.info("Fatal error : invalid serialized_parameters_format : %s." % serialized_parameters_format)
                    quit()
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'read_parameters_from_model' : (toc-tic)})

                tic = time.time()
                rsconn.set("parameters:current", current_parameters_str)
                rsconn.set("parameters:current_timestamp", time.time())
                rsconn.set("parameters:num_minibatches_master_processed", num_minibatches_master_processed)
                # potentially not used
                rsconn.set("parameters:current_datestamp", time.strftime("%Y-%m-%d %H:%M:%S"))
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'sync_params_to_database' : (toc-tic)})
                print "The master has updated the parameters."

            elif next_action == "wait_for_workers_to_update_all_the_importance_weights":

                # This next line is to ask for the master to wait until everything has been
                # updated. This can take a minute or so, and it's not a very good approach.
                # However, it's the way to see what would happen if we implemented ISGD exactly
                # without using any stale importance weights.
                #    wait_until_all_measurements_are_updated_by_workers(rsconn, "train", "importance_weight")
                logging.info("Error. Take the time to test wait_for_workers_to_update_all_the_importance_weights if you want to use it. I expect it to work, though.")

            elif next_action == "refresh_importance_weights":

                remote_redis_logger.log('event', "refresh_importance_weights")
                tic = time.time()
                _, D_importance_weights_and_more = get_raw_importance_weights(rsconn)

                (_, D_importance_weights_and_more, extra_statistics) = filter_raw_importance_weights(
                                                                            D_importance_weights_and_more,
                                                                            staleness_threshold_seconds=staleness_threshold_seconds,
                                                                            staleness_threshold_num_minibatches_master_processed=staleness_threshold_num_minibatches_master_processed,
                                                                            importance_weight_additive_constant=importance_weight_additive_constant,
                                                                            num_minibatches_master_processed=num_minibatches_master_processed)

                record_importance_weights_statistics(   D_importance_weights_and_more, extra_statistics,
                                            remote_redis_logger=remote_redis_logger, logging=logging,
                                            want_compute_entropy=True)

                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'refresh_importance_weights' : (toc-tic)})
                #print "The master has obtained fresh importance weights."

            elif next_action == "process_minibatch":

                remote_redis_logger.log('event', "process_minibatch")
                #if A_importance_weights is None or nbr_of_usable_importance_weights is None:
                #    # nothing can be done here
                #    remote_redis_logger.log('event', "process_minibatch skipped")
                #    continue
                #else:
                #    remote_redis_logger.log('event', "process_minibatch")

                # Cause importance sampling to be done randomly if the value of
                # `turn_off_importance_sampling` is a floating-point value.
                # Note that another good approach would have been to alternate between
                # one mode and the other.
                if type(turn_off_importance_sampling) == float:
                    assert 0.0 <= turn_off_importance_sampling
                    assert turn_off_importance_sampling <= 1.0
                    if np.random.rand() <= turn_off_importance_sampling:
                        decision_to_turn_off_importance_sampling_this_iteration = True
                    else:
                        decision_to_turn_off_importance_sampling_this_iteration = False
                else:
                    assert type(turn_off_importance_sampling) == bool
                    decision_to_turn_off_importance_sampling_this_iteration = turn_off_importance_sampling
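                # (A sketch of the alternation idea mentioned in the comment above,
                #  not what this code does: the decision could flip deterministically
                #  between minibatches, e.g.
                #      decision_to_turn_off_importance_sampling_this_iteration = \
                #          (int(num_minibatches_master_processed) % 2 == 0)
                #  instead of being drawn at random.)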

                tic = time.time()
                (intent, mode, A_sampled_indices, A_scaling_factors) = sample_indices_and_scaling_factors(
                        D_importance_weights_and_more=D_importance_weights_and_more,
                        extra_statistics=extra_statistics,
                        nbr_samples=master_minibatch_size,
                        master_usable_importance_weights_threshold_to_ISGD=master_usable_importance_weights_threshold_to_ISGD,
                        want_master_to_do_USGD_when_ISGD_is_not_possible=want_master_to_do_USGD_when_ISGD_is_not_possible,
                        turn_off_importance_sampling=decision_to_turn_off_importance_sampling_this_iteration)
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'sample_indices_and_scaling_factors' : (toc-tic)})

                if intent == 'wait_and_retry':
                    remote_redis_logger.log('event', "Master does not have enough importance weights to do ISGD, and doesn't want to default to USGD. Sleeping for 2 seconds.")
                    time.sleep(2.0)
                    continue

                if intent == 'proceed':
                    remote_redis_logger.log('event', "Master proceeding with round of %s." % (mode,))
                    tic = time.time()

                    #if not np.all(np.isfinite(A_scaling_factors)):
                    #    import pdb; pdb.set_trace()

                    model_api.master_process_minibatch(A_sampled_indices, A_scaling_factors, "train")
                    toc = time.time()
                    remote_redis_logger.log('timing_profiler', {'master_process_minibatch' : (toc-tic), 'mode':mode})
                    logging.info("The master has processed a minibatch using %s. %f, %s" % (mode, time.time(), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
                    num_minibatches_master_processed += 1


    remote_redis_logger.log('event', "Master exited from main loop")
    logging.info("Master exited from main loop. %f, %s" % (time.time(), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
    remote_redis_logger.close()
    quit()
Code Example #2
def run(DD_config, rserv, rsconn, bootstrap_file, D_server_desc):

    importance_weight_additive_constant = DD_config['database'][
        'importance_weight_additive_constant']

    # set up logging system for logging_folder
    setup_python_logger(folder=DD_config["database"]["logging_folder"])
    logging.info(pprint.pformat(DD_config))

    # set up logging system to the redis server
    remote_redis_logger = integration_distributed_training.server.logger.RedisLogger(
        rsconn, queue_prefix_identifier="service_database")
    integration_distributed_training.server.logger.record_machine_info(
        remote_redis_logger)

    if check_if_any_initialization_has_even_been_done(rsconn):
        #rsconn.set("resuming_from_previous_run", True)
        msg = "Resuming from previous run."
        remote_redis_logger.log('event', msg)
        logging.info(msg)

        # So we're about to resume training. Just as a sanity check, though, we should
        # really 1) make sure that the parameters are all there
        #        2) the contents of "L_workers_%s_minibatch_indices_ALL" is used to populate "L_workers_%s_minibatch_indices_QUEUE"
        #

        if not check_if_parameters_are_present(rsconn):
            msg = "Error. We are supposed to be resuming from a previous training session, but the parameters are not found in the database."
            remote_redis_logger.log('event', msg)
            logging.info(msg)
            quit()

        refresh_QUEUE_from_ALL(rsconn, DD_config['database']['L_segments'],
                               remote_redis_logger, logging)
        set_initialization_as_done(rsconn, D_server_desc)

    else:

        configure(rsconn,
                  Ntrain=DD_config['model']['Ntrain'],
                  Nvalid=DD_config['model']['Nvalid'],
                  Ntest=DD_config['model']['Ntest'],
                  **DD_config['database'])
        #rsconn.set("resuming_from_previous_run", False)
        msg = "Starting a new run."
        remote_redis_logger.log('event', msg)
        logging.info(msg)

        set_initialization_as_done(rsconn, D_server_desc)

    # Use `rserv` to be able to shut down the
    # redis-server when the user hits CTRL+C.
    # Otherwise, the server is left in the background
    # and this can cause problems due to scripts
    # getting tangled together.

    def signal_handler(signal, frame):
        logging.info("You pressed CTRL+C.")
        logging.info("Closing the remote_redis_logger.")
        remote_redis_logger.log('event', "Received SIGTERM.")
        remote_redis_logger.close()
        logging.info("Sending save and shutdown commands to the redis-server.")
        rserv.stop(want_save=True)
        delete_bootstrap_file(bootstrap_file)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    maximum_validation_accuracy = -1.0

    def sansgton(x):
        # Sanitize a singleton value so it can be written to json.
        # It could be None, but float(None) fails. It could also be a np.float32
        # or np.float64, which json can't serialize directly, so we call float(x)
        # on those values.
        if x is None:
            return float(np.nan)
        else:
            return float(x)
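    # For illustration (hypothetical values): sansgton(None) -> nan and
    # sansgton(np.float32(0.93)) -> 0.93, so the values logged below are
    # plain Python floats that the json layer can handle.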

    remote_redis_logger.log('event',
                            "Before entering service_database main loop.")

    last_save_database_timestamp = None
    # save every 5 minutes (this should remove some pressure)
    save_database_threshold = 5 * 60

    while True:
        logging.info(
            "Running server. Press CTLR+C to stop. Timestamp %f. %s" %
            (time.time(),
             time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
        logging.info(
            "Number minibatches processed by master    " +
            str(rsconn.get("parameters:num_minibatches_master_processed")))

        for segment in ["train", "valid", "test"]:
            logging.info("-- %s " % segment)
            for measurement in [
                    "individual_loss", "individual_accuracy",
                    "individual_gradient_square_norm"
            ]:
                (mean, variance, N,
                 r) = get_mean_variance_measurement_on_database(
                     rsconn, segment, measurement)
                std = np.sqrt(variance)
                if segment == "valid" and measurement == "individual_accuracy" and mean > maximum_validation_accuracy:
                    maximum_validation_accuracy = mean
                    logging.info(
                        "                                                                      ---Highest Validation Accuracy so Far---"
                    )
                logging.info(
                    "---- %s : mean %f, std %f    with %0.4f of values used." %
                    (measurement, mean, std, r))
                remote_redis_logger.log(
                    'measurement', {
                        'name': measurement,
                        'segment': segment,
                        'mean': sansgton(mean),
                        'std': sansgton(std),
                        'ratio_used': r
                    })

        logging.info("Highest Validation Accuracy seen so far " +
                     str(maximum_validation_accuracy))

        time.sleep(10.0)

        # TODO : Figure out a way to factor the staleness into this whole thing
        #        without making a horrible mess of spaghetti.

        # This is just extra reporting. The computation with 0.0 is always included.
        L_importance_weight_additive_constant = [
            importance_weight_additive_constant, 0.0, 0.0001, 0.001, 0.01, 0.1,
            1.0, 10.0
        ]

        (usgd2, staleisgd2, isgd2, mu2,
         ratio_of_usable_indices_for_USGD_and_ISGD,
         ratio_of_usable_indices_for_ISGDstale, nbr_minibatches,
         D_other_staleISGD_main_term) = get_trace_covariance_information(
             rsconn,
             "train",
             L_importance_weight_additive_constant=
             L_importance_weight_additive_constant)
        # Make sure that you have a reasonable number of readings before
        # reporting those statistics.
        if 0.1 <= ratio_of_usable_indices_for_USGD_and_ISGD:
            assert usgd2 is not None and isgd2 is not None and mu2 is not None
            logging.info(
                "Approximate squared norm of the mean gradient over the whole dataset : %0.12f."
                % (mu2, ))
            logging.info("Trace(Cov USGD) without mu2 : %0.12f." % (usgd2, ))
            logging.info("Trace(Cov ISGD) without mu2 : %0.12f." % (isgd2, ))
        else:
            logging.info(
                "ratio_of_usable_indices_for_USGD_and_ISGD %f not high enough to report those numbers"
                % ratio_of_usable_indices_for_USGD_and_ISGD)

        if 0.1 <= ratio_of_usable_indices_for_ISGDstale:
            logging.info("Trace(Cov Stale ISGD) without mu2 : %0.12f." %
                         (staleisgd2, ))
            for (k, v) in D_other_staleISGD_main_term.items():
                logging.info(
                    "Trace(Cov Stale ISGD) without mu2 : %0.12f.  Using importance_weight_additive_constant %f."
                    % (v, k))
        else:
            logging.info(
                "ratio_of_usable_indices_for_ISGDstale %f not high enough to report those numbers"
                % ratio_of_usable_indices_for_ISGDstale)
        logging.info("")

        remote_redis_logger.log(
            'SGD_trace_variance', {
                'approx_mu2':
                sansgton(mu2),
                'usgd2':
                sansgton(usgd2),
                'isgd2':
                sansgton(isgd2),
                'staleisgd2':
                sansgton(staleisgd2),
                'extra_staleisgd2':
                D_other_staleISGD_main_term,
                'ratio_of_usable_indices_for_USGD_and_ISGD':
                ratio_of_usable_indices_for_USGD_and_ISGD,
                'ratio_of_usable_indices_for_ISGDstale':
                ratio_of_usable_indices_for_ISGDstale
            })

        time.sleep(10.0)
        logging.info("")

        # periodically have the database save itself to disk (throttled below by save_database_threshold)
        if DD_config['database']['want_rdb_background_save']:
            # redundant logic, but clearer to read
            if ((last_save_database_timestamp is None)
                    or (save_database_threshold <=
                        (time.time() - last_save_database_timestamp))):
                last_save_database_timestamp = time.time()
                rsconn.bgsave()
                logging.info("Database called rsconn.bgsave(). %f. %s" %
                             (time.time(),
                              time.strftime('%Y-%m-%d %H:%M:%S',
                                            time.localtime(time.time()))))
                continue
Code Example #3
def run(DD_config, rserv, rsconn, bootstrap_file, D_server_desc):

    importance_weight_additive_constant = DD_config["database"]["importance_weight_additive_constant"]

    # set up logging system for logging_folder
    setup_python_logger(folder=DD_config["database"]["logging_folder"])
    logging.info(pprint.pformat(DD_config))

    # set up logging system to the redis server
    remote_redis_logger = integration_distributed_training.server.logger.RedisLogger(
        rsconn, queue_prefix_identifier="service_database"
    )
    integration_distributed_training.server.logger.record_machine_info(remote_redis_logger)

    if check_if_any_initialization_has_even_been_done(rsconn):
        # rsconn.set("resuming_from_previous_run", True)
        msg = "Resuming from previous run."
        remote_redis_logger.log("event", msg)
        logging.info(msg)

        # So we're about to resume training. Just as a sanity check, though, we should
        # really 1) make sure that the parameters are all there
        #        2) the contents of "L_workers_%s_minibatch_indices_ALL" is used to populate "L_workers_%s_minibatch_indices_QUEUE"
        #

        if not check_if_parameters_are_present(rsconn):
            msg = "Error. We are supposed to be resuming from a previous training session, but the parameters are not found in the database."
            remote_redis_logger.log("event", msg)
            logging.info(msg)
            quit()

        refresh_QUEUE_from_ALL(rsconn, DD_config["database"]["L_segments"], remote_redis_logger, logging)
        set_initialization_as_done(rsconn, D_server_desc)

    else:

        configure(
            rsconn,
            Ntrain=DD_config["model"]["Ntrain"],
            Nvalid=DD_config["model"]["Nvalid"],
            Ntest=DD_config["model"]["Ntest"],
            **DD_config["database"]
        )
        # rsconn.set("resuming_from_previous_run", False)
        msg = "Starting a new run."
        remote_redis_logger.log("event", msg)
        logging.info(msg)

        set_initialization_as_done(rsconn, D_server_desc)

    # Use `rserv` to be able to shut down the
    # redis-server when the user hits CTRL+C.
    # Otherwise, the server is left in the background
    # and this can cause problems due to scripts
    # getting tangled together.

    def signal_handler(signal, frame):
        logging.info("You pressed CTRL+C.")
        logging.info("Closing the remote_redis_logger.")
        remote_redis_logger.log("event", "Received SIGTERM.")
        remote_redis_logger.close()
        logging.info("Sending save and shutdown commands to the redis-server.")
        rserv.stop(want_save=True)
        delete_bootstrap_file(bootstrap_file)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    maximum_validation_accuracy = -1.0

    def sansgton(x):
        # Sanitize a singleton value so it can be written to json.
        # It could be None, but float(None) fails. It could also be a np.float32
        # or np.float64, which json can't serialize directly, so we call float(x)
        # on those values.
        if x is None:
            return float(np.nan)
        else:
            return float(x)

    remote_redis_logger.log("event", "Before entering service_database main loop.")

    last_save_database_timestamp = None
    # save every 5 minutes (this should remove some pressure)
    save_database_threshold = 5 * 60

    while True:
        logging.info(
            "Running server. Press CTLR+C to stop. Timestamp %f. %s"
            % (time.time(), time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))
        )
        logging.info(
            "Number minibatches processed by master    "
            + str(rsconn.get("parameters:num_minibatches_master_processed"))
        )

        for segment in ["train", "valid", "test"]:
            logging.info("-- %s " % segment)
            for measurement in ["individual_loss", "individual_accuracy", "individual_gradient_square_norm"]:
                (mean, variance, N, r) = get_mean_variance_measurement_on_database(rsconn, segment, measurement)
                std = np.sqrt(variance)
                if segment == "valid" and measurement == "individual_accuracy" and mean > maximum_validation_accuracy:
                    maximum_validation_accuracy = mean
                    logging.info(
                        "                                                                      ---Highest Validation Accuracy so Far---"
                    )
                logging.info("---- %s : mean %f, std %f    with %0.4f of values used." % (measurement, mean, std, r))
                remote_redis_logger.log(
                    "measurement",
                    {
                        "name": measurement,
                        "segment": segment,
                        "mean": sansgton(mean),
                        "std": sansgton(std),
                        "ratio_used": r,
                    },
                )

        logging.info("Highest Validation Accuracy seen so far " + str(maximum_validation_accuracy))

        time.sleep(10.0)

        # TODO : Figure out a way to factor the staleness into this whole thing
        #        without making a horrible mess of spaghetti.

        # This is just extra reporting. The computation with 0.0 is always included.
        L_importance_weight_additive_constant = [
            importance_weight_additive_constant,
            0.0,
            0.0001,
            0.001,
            0.01,
            0.1,
            1.0,
            10.0,
        ]

        (
            usgd2,
            staleisgd2,
            isgd2,
            mu2,
            ratio_of_usable_indices_for_USGD_and_ISGD,
            ratio_of_usable_indices_for_ISGDstale,
            nbr_minibatches,
            D_other_staleISGD_main_term,
        ) = get_trace_covariance_information(
            rsconn, "train", L_importance_weight_additive_constant=L_importance_weight_additive_constant
        )
        # Make sure that you have a reasonable number of readings before
        # reporting those statistics.
        if 0.1 <= ratio_of_usable_indices_for_USGD_and_ISGD:
            assert usgd2 is not None and isgd2 is not None and mu2 is not None
            logging.info("Approximative norm squares of the mean gradient over whole dataset : %0.12f." % (mu2,))
            logging.info("Trace(Cov USGD) without mu2 : %0.12f." % (usgd2,))
            logging.info("Trace(Cov ISGD) without mu2: %0.12f." % (isgd2,))
        else:
            logging.info(
                "ratio_of_usable_indices_for_USGD_and_ISGD %f not high enough to report those numbers"
                % ratio_of_usable_indices_for_USGD_and_ISGD
            )

        if 0.1 <= ratio_of_usable_indices_for_ISGDstale:
            logging.info("Trace(Cov Stale ISGD) without mu2 : %0.12f." % (staleisgd2,))
            for (k, v) in D_other_staleISGD_main_term.items():
                logging.info(
                    "Trace(Cov Stale ISGD) without mu2 : %0.12f.  Using importance_weight_additive_constant %f."
                    % (v, k)
                )
        else:
            logging.info(
                "ratio_of_usable_indices_for_ISGDstale %f not high enough to report those numbers"
                % ratio_of_usable_indices_for_ISGDstale
            )
        logging.info("")

        remote_redis_logger.log(
            "SGD_trace_variance",
            {
                "approx_mu2": sansgton(mu2),
                "usgd2": sansgton(usgd2),
                "isgd2": sansgton(isgd2),
                "staleisgd2": sansgton(staleisgd2),
                "extra_staleisgd2": D_other_staleISGD_main_term,
                "ratio_of_usable_indices_for_USGD_and_ISGD": ratio_of_usable_indices_for_USGD_and_ISGD,
                "ratio_of_usable_indices_for_ISGDstale": ratio_of_usable_indices_for_ISGDstale,
            },
        )

        time.sleep(10.0)
        logging.info("")

        # periodically have the database save itself to disk (throttled below by save_database_threshold)
        if DD_config["database"]["want_rdb_background_save"]:
            # redundant logic, but clearer to read
            if (last_save_database_timestamp is None) or (
                save_database_threshold <= (time.time() - last_save_database_timestamp)
            ):
                last_save_database_timestamp = time.time()
                rsconn.bgsave()
                logging.info(
                    "Database called rsconn.bgsave(). %f. %s"
                    % (time.time(), time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))
                )
                continue