def run(DD_config, D_server_desc):
    """Service-worker main loop for distributed training.

    Connects to the redis database described by `D_server_desc`, then loops
    forever over `worker_routine`, alternating between two actions:

      - "sync_params": fetch fresh model parameters from the database
        whenever their timestamp has changed, and load them into `model_api`;
      - "process_minibatch": pop a minibatch-indices string from one of the
        segment queues (train/valid/test, sampled by priority), compute the
        requested measurements through `model_api.worker_process_minibatch`,
        write them (plus staleness bookkeeping) back to the database, and
        re-queue the minibatch at the back of the queue.

    Never returns normally: exits via quit() on fatal errors or via
    sys.exit(0) from the SIGINT handler.

    :param DD_config: nested configuration dict; reads the 'database' section
        (connection_setup_timeout, L_measurements, serialized_parameters_format)
        and the 'model' section (worker_routine, plus whatever ModelAPI needs).
    :param D_server_desc: dict describing the redis server (at least
        'hostname'); mutated in place by the OSX hack below.
    """

    # TODO : Get rid of this cheap hack to circumvent my OSX's inability to see itself.
    if D_server_desc['hostname'] in ["szkmbp"]:
        D_server_desc['hostname'] = "localhost"

    # Blocks until the master has published parameters (wait_for_parameters_to_be_present=True).
    rsconn = get_rsconn_with_timeout(D_server_desc,
                                     timeout=DD_config['database']['connection_setup_timeout'], wait_for_parameters_to_be_present=True)


    # NOTE(review): this counter is never incremented or read below — looks vestigial.
    num_importance_weight_batches_processed = 0

    L_measurements = DD_config['database']['L_measurements']
    serialized_parameters_format = DD_config['database']['serialized_parameters_format']
    worker_routine = DD_config['model']['worker_routine']
    # The routine must begin with "sync_params" so the worker never processes
    # a minibatch with uninitialized parameters.
    if worker_routine[0] != "sync_params":
        print "Error. Your worker_routine should always start with 'sync_params'."
        print worker_routine
        quit()

    remote_redis_logger = integration_distributed_training.server.logger.RedisLogger(rsconn, queue_prefix_identifier="service_worker")

    def signal_handler(signal, frame):
        # Flush/close the remote logger before dying so the last events reach the database.
        print "You pressed CTRL+C."
        print "Closing the remote_redis_logger."
        remote_redis_logger.log('event', "Received SIGTERM.")
        remote_redis_logger.close()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    model_api = ModelAPI(DD_config['model'])
    # This `record_machine_info` has to be called after the component that
    # makes use of theano if we hope to properly record the theano.config.
    integration_distributed_training.server.logger.record_machine_info(remote_redis_logger)


    # Segment sampling: 'train' is drawn ~20x more often than 'valid'/'test'.
    #D_segment_priorities = {'train' : 50, 'valid' : 1, 'test' : 1}
    segment_priorities_p = np.array([20, 1, 1], dtype=np.float32)
    segment_priorities_p /= segment_priorities_p.sum()
    # After cumsum the last entry is 1.0 (up to float rounding), so
    # `sample_segment` below always finds a bucket for r in [0, 1).
    segment_priorities_p = segment_priorities_p.cumsum()
    segment_priorities_v = ['train', 'valid', 'test']
    def sample_segment():
        # Draw a segment name by inverse-CDF sampling against the cumulative priorities.
        r = np.random.rand()
        for (i, e) in enumerate(segment_priorities_p):
            if r <= e:
                return segment_priorities_v[i]


    # The worker has to watch two things.
    #
    # (1) Have the parameters been updated on the server ?
    #     Check out the timestamp to determine if they have been updated.
    #     (Because of the assumption that the master updates the parameters
    #     and *then* the timestamp.)
    #     If they have been updated, we want to fetch a copy a convert it
    #     to a numpy array.
    #
    # (2) Process left-most entries from L_workers_%s_minibatch_indices_QUEUE
    #     according to their priorities given by `segment_priorities_p`.
    #
    # This will be useful for the workers when they want to stamp their importance weights
    # with a timestamp that reflects the last time that they got a fresh set of parameters.
    current_parameters = None
    parameters_current_timestamp_str = ""
    parameters_num_minibatches_master_processed_str = ""

    remote_redis_logger.log('event', "Before entering service_worker main loop.")
    while True:

        for next_action in worker_routine:
            assert next_action in [ "sync_params", "process_minibatch"]

            if next_action == "sync_params":

                remote_redis_logger.log('event', "sync_params")

                new_parameters_current_timestamp_str = rsconn.get("parameters:current_timestamp")

                # Note that `rsconn.get("parameters:num_minibatches_master_processed")` is reflected
                # right back to the database without us even parsing it at all. It's meant to be an alternate
                # way to specify the staleness.

                # Only refetch the (potentially large) parameter blob when the timestamp changed.
                if parameters_current_timestamp_str != new_parameters_current_timestamp_str:
                    tic = time.time()
                    current_parameters_str = rsconn.get("parameters:current")
                    toc = time.time()
                    remote_redis_logger.log('timing_profiler', {'sync_params_from_database' : (toc-tic)})

                    if len(current_parameters_str) == 0:
                        print "Error. No parameters found in the server."
                        print "We could recover from this error by just ignoring it, and getting the parameters the next time around."
                        quit()

                    if serialized_parameters_format == "opaque_string":
                        # Pass the serialized blob straight through to the model.
                        parameters_current_timestamp_str = new_parameters_current_timestamp_str
                        parameters_num_minibatches_master_processed_str = rsconn.get("parameters:num_minibatches_master_processed")
                        tic = time.time()
                        model_api.set_serialized_parameters(current_parameters_str)
                        toc = time.time()
                        remote_redis_logger.log('timing_profiler', {'model_api.set_serialized_parameters' : (toc-tic)})
                        print "The worker has received new parameters. This took %f seconds." % (toc - tic,)
                        continue
                    elif serialized_parameters_format == "ndarray_float32_tostring":
                        # Deserialize the raw float32 buffer into a numpy array first.
                        current_parameters = np.fromstring(current_parameters_str, dtype=np.float32)
                        parameters_current_timestamp_str = new_parameters_current_timestamp_str
                        parameters_num_minibatches_master_processed_str = rsconn.get("parameters:num_minibatches_master_processed")
                        tic = time.time()
                        model_api.set_serialized_parameters(current_parameters)
                        toc = time.time()
                        remote_redis_logger.log('timing_profiler', {'model_api.set_serialized_parameters' : (toc-tic)})
                        print "The worker has received new parameters. This took %f seconds." % (toc - tic,)
                        continue
                    else:
                        print "Fatal error : invalid serialized_parameters_format : %s." % serialized_parameters_format
                        quit()

            elif next_action == "process_minibatch":

                remote_redis_logger.log('event', "process_minibatch")
                tic_process_minibatch = time.time()

                segment = sample_segment()
                queue_name = "L_workers_%s_minibatch_indices_QUEUE" % segment
                if rsconn.llen(queue_name) == 0:
                    msg = "The worker has nothing to do.\nThe queue %s is empty." % queue_name
                    print msg
                    remote_redis_logger.log('event', msg)
                    # TODO : Adjust the duration of the sleep.
                    time.sleep(0.2)
                    continue

                current_minibatch_indices_str = rsconn.lpop(queue_name)
                if current_minibatch_indices_str is None or len(current_minibatch_indices_str) == 0:
                    # This is very unexpected, because it implies that we have a queue
                    # that is shorter than the number of workers. It's not illegal, but
                    # just generally not recommended for a setup.
                    msg = "The worker has nothing to do.\nIt is as though queue %s was empty when we tried to pop an element from the left." % queue_name
                    print msg
                    remote_redis_logger.log('event', msg)
                    # TODO : Adjust the duration of the sleep.
                    time.sleep(0.2)
                    continue


                # The queue entry is itself the serialized int32 indices; it also
                # serves verbatim as the hash-field key for all measurement hashes below.
                current_minibatch_indices = np.fromstring(current_minibatch_indices_str, dtype=np.int32)

                # There is a special thing to do with the individual_importance_weight.
                # We want to keep around their previous values.

                tic = time.time()
                tmp_str = rsconn.hget("H_%s_minibatch_%s" % (segment, "individual_importance_weight"), current_minibatch_indices_str)
                rsconn.hset("H_%s_minibatch_%s" % (segment, "previous_individual_importance_weight"), current_minibatch_indices_str, tmp_str)
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'copied individual_importance_weight to previous_individual_importance_weight' : (toc-tic)})

                tic = time.time()
                # This returns a dictionary of numpy arrays.
                DA_measurements = model_api.worker_process_minibatch(current_minibatch_indices, segment, L_measurements)
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'worker_process_minibatch' : (toc-tic)})

                tic_send_measurements_to_database = time.time()
                # Update the measurements. Update the timestamps.
                # None of the measurements should be missing.
                for measurement in L_measurements:

                    A_values = DA_measurements[measurement]
                    assert type(A_values) == np.ndarray, "Your `worker_process_minibatch` function is supposed to return an array of np.float32 as measurements (%s), but now those values are not even numpy arrays. They are %s instead." % (measurement, type(A_values))
                    if A_values.dtype == np.float64:
                        # This conversion is acceptable.
                        A_values = A_values.astype(np.float32)
                    assert A_values.dtype == np.float32, "Your `worker_process_minibatch` function is supposed to return an array of np.float32 as measurements (%s), but now that array has dtype %s instead." % (measurement, A_values.dtype)

                    # NaN/inf measurements would poison the importance-sampling
                    # machinery on the master side, so treat them as fatal.
                    number_of_invalid_values = np.logical_not(np.isfinite(A_values)).sum()
                    if 0 < number_of_invalid_values:
                        msg = "FATAL ERROR. You have %d invalid values returned for %s." % (number_of_invalid_values, measurement)
                        print msg
                        print A_values
                        remote_redis_logger.log('event', msg)
                        quit()
                        #print "Starting debugger."
                        #import pdb; pdb.set_trace()

                    rsconn.hset("H_%s_minibatch_%s" % (segment, measurement), current_minibatch_indices_str, A_values.tostring(order='C'))

                    previous_update_timestamp_str = rsconn.hget("H_%s_minibatch_%s_measurement_last_update_timestamp" % (segment, measurement), current_minibatch_indices_str)
                    if previous_update_timestamp_str is None or len(previous_update_timestamp_str) == 0:
                        print "The measurements are supposed to be initialized when starting the database."
                        print "They are supposed to have a timestamp set at that time."
                        print "This is not a serious error from which we could not recover, but it signals that there is a bug, so let's quit() here."
                        quit()
                        #previous_update_timestamp = 0.0
                    else:
                        previous_update_timestamp = float(previous_update_timestamp_str)

                    print "(%s, %s) timestamp delta between updates to that measurement : %f" % (segment, measurement, time.time() - previous_update_timestamp, )

                    current_update_timestamp = time.time()
                    rsconn.hset("H_%s_minibatch_%s_measurement_last_update_timestamp" % (segment, measurement), current_minibatch_indices_str, current_update_timestamp)

                    # This string is reflected intact to the database. It does not get parsed. The usefulness of this comes
                    # from the fact that the master can then how many "minibatches ago" the parameters came.
                    # This is basically the same thing as the "delay_between_measurement_update_and_parameter_update"
                    # below, but it's an absolute value instead of being a difference.
                    rsconn.hset("H_%s_minibatch_%s_measurement_num_minibatches_master_processed" % (segment, measurement), current_minibatch_indices_str, parameters_num_minibatches_master_processed_str)

                    delay_between_measurement_update = current_update_timestamp - previous_update_timestamp
                    delay_between_measurement_update_and_parameter_update = current_update_timestamp - float(parameters_current_timestamp_str)

                    rsconn.hset("H_%s_minibatch_%s_delay_between_measurement_update" % (segment, measurement), current_minibatch_indices_str, delay_between_measurement_update)
                    rsconn.hset("H_%s_minibatch_%s_delay_between_measurement_update_and_parameter_update" % (segment, measurement), current_minibatch_indices_str, delay_between_measurement_update_and_parameter_update)


                    #print "delay_between_measurement_update : %f seconds" % delay_between_measurement_update
                    #print "delay_between_measurement_update_and_parameter_update : %f seconds" % delay_between_measurement_update_and_parameter_update

                    # Be careful. If you re-indent the next block deeper,
                    # you'll mess up everything with the re-queuing of the minibatches.

                toc_send_measurements_to_database = time.time()
                remote_redis_logger.log('timing_profiler', {'send_measurements_to_database' : (toc_send_measurements_to_database-tic_send_measurements_to_database)})
                # We could log this for every measurement, but we'll just log it for one of them.
                # Otherwise, this is multiplying the messaging without real need.
                # We'll use the values for the last measurement, which outlasts the for loop above
                # due to shitty python scoping.
                remote_redis_logger.log('delay', {'delay_between_measurement_update' : delay_between_measurement_update, 'delay_between_measurement_update_and_parameter_update':delay_between_measurement_update_and_parameter_update})


                # Push back that minibatch to the right of the queue.
                # It will eventually find its way back to some worker,
                # but we will cover all the other ones before that happens.
                rsconn.rpush(queue_name, current_minibatch_indices_str)
                toc_process_minibatch = time.time()
                remote_redis_logger.log('timing_profiler', {'process_minibatch' : (toc_process_minibatch-tic_process_minibatch)})
                msg = "Processed one minibatch from %s. Pushed back to back of the line. Total time taken is %f seconds." % (segment, toc_process_minibatch-tic_process_minibatch)
                print msg
                remote_redis_logger.log('event', msg)
# --- Exemplo n.º 2 (example separator from the original source dump) ---
def run(DD_config, D_server_desc):
    """Service-master main loop for distributed importance-sampled SGD.

    Connects to the redis database, restores model parameters when resuming
    from a previous run, then loops forever over `master_routine` performing:

      - "sync_params": publish the model's current parameters (plus timestamp
        and minibatch counter) to the database for the workers to fetch;
      - "refresh_importance_weights": read back the importance weights that
        the workers computed, filtered by the staleness thresholds;
      - "process_minibatch": sample training indices (ISGD when enough usable
        importance weights are available, otherwise USGD if allowed) and run
        one training step through `model_api.master_process_minibatch`;
      - "wait_for_workers_to_update_all_the_importance_weights": currently a
        stub that only logs a warning.

    Never returns normally: exits via the SIGINT handler (sys.exit) or a
    fatal-error quit().

    :param DD_config: nested configuration dict; reads the 'database' section
        (minibatch size, serialization format, staleness thresholds,
        logging_folder, ...) and the 'model' section (master_routine, Ntrain,
        turn_off_importance_sampling, ...).
    :param D_server_desc: dict describing the redis server (at least
        'hostname'); mutated in place by the OSX hack below.
    """

    # TODO : Get rid of this cheap hack to circumvent my OSX's inability to see itself.
    if D_server_desc['hostname'] in ["szkmbp"]:
        D_server_desc['hostname'] = "localhost"

    # The master is the one that publishes the parameters, so it must not
    # wait for them to be present (wait_for_parameters_to_be_present=False).
    rsconn = get_rsconn_with_timeout(D_server_desc,
                                     timeout=DD_config['database']['connection_setup_timeout'], wait_for_parameters_to_be_present=False)

    L_measurements = DD_config['database']['L_measurements']

    master_minibatch_size = DD_config['database']['master_minibatch_size']
    serialized_parameters_format = DD_config['database']['serialized_parameters_format']
    Ntrain = DD_config['model']['Ntrain']
    # Default behavior is to have no staleness, and perform ISGD from the moment that we
    # get values for all the importance weights. Until then, we do USGD.
    staleness_threshold_seconds = DD_config['database']['staleness_threshold_seconds']
    staleness_threshold_num_minibatches_master_processed = DD_config['database']['staleness_threshold_num_minibatches_master_processed']
    importance_weight_additive_constant = DD_config['database']['importance_weight_additive_constant']

    want_master_to_do_USGD_when_ISGD_is_not_possible = DD_config['database'].get('want_master_to_do_USGD_when_ISGD_is_not_possible', True)
    master_usable_importance_weights_threshold_to_ISGD = DD_config['database'].get('master_usable_importance_weights_threshold_to_ISGD', 1.0)
    master_routine = DD_config['model']['master_routine']
    # The routine must begin with "sync_params" so workers never see a
    # database without parameters published by this master.
    if master_routine[0] != "sync_params":
        print "Error. Your master_routine should always start with 'sync_params'."
        print master_routine
        quit()
    turn_off_importance_sampling = DD_config["model"].get("turn_off_importance_sampling", False)



    # set up python logging system for logging_folder
    setup_python_logger(folder=DD_config["database"]["logging_folder"])
    logging.info(pprint.pformat(DD_config))

    remote_redis_logger = integration_distributed_training.server.logger.RedisLogger(rsconn, queue_prefix_identifier="service_master")
    model_api = ModelAPI(DD_config['model'])
    # This `record_machine_info` has to be called after the component that
    # makes use of theano if we hope to properly record the theano.config.
    integration_distributed_training.server.logger.record_machine_info(remote_redis_logger)

    # It's very important to determine if we're resuming from a previous run,
    # in which case we really want to load the parameters to resume training.
    if not check_if_parameters_are_present(rsconn):
        ### resuming_from_previous_run = False ###
        msg = "Starting a new run."
        remote_redis_logger.log('event', msg)
        logging.info(msg)
    else:
        ### resuming_from_previous_run = True ###

        msg = "Resuming from previous run."
        remote_redis_logger.log('event', msg)
        logging.info(msg)

        # This whole section is taken almost exactly from the service_worker.
        tic = time.time()
        current_parameters_str = rsconn.get("parameters:current")
        toc = time.time()
        remote_redis_logger.log('timing_profiler', {'sync_params_from_database' : (toc-tic)})

        if len(current_parameters_str) == 0:
            print "Error. No parameters found in the server."
            quit()

        if serialized_parameters_format == "opaque_string":
            tic = time.time()
            model_api.set_serialized_parameters(current_parameters_str)
            toc = time.time()
            remote_redis_logger.log('timing_profiler', {'model_api.set_serialized_parameters' : (toc-tic)})
            logging.info("The master has received initial parameters. This took %f seconds." % (toc - tic,))

        elif serialized_parameters_format == "ndarray_float32_tostring":
            # BUGFIX : the original referenced two undefined names here
            # (`new_parameters_current_timestamp_str` and `current_parameters`),
            # which made this resume path raise NameError unconditionally.
            # Deserialize the raw float32 buffer exactly like the
            # service_worker does before handing it to the model.
            current_parameters = np.fromstring(current_parameters_str, dtype=np.float32)
            tic = time.time()
            model_api.set_serialized_parameters(current_parameters)
            toc = time.time()
            remote_redis_logger.log('timing_profiler', {'model_api.set_serialized_parameters' : (toc-tic)})
            logging.info("The master has received initial parameters. This took %f seconds." % (toc - tic,))

        else:
            logging.info("Fatal error : invalid serialized_parameters_format : %s." % serialized_parameters_format)
            quit()


    # Run just a simple test to make sure that the importance weights have been
    # set to something. In theory, there should always be valid values in there,
    # so this is just a sanity check.
    segment = "train"
    measurement = "individual_importance_weight"
    #nbr_of_present_importance_weights = rsconn.hlen("H_%s_minibatch_%s" % (segment, measurement))
    #assert 0 < nbr_of_present_importance_weights, "Error. The database should have been set up to have dummy importance weights at least."

    #print "Master found %d importance weights in the database." % nbr_of_present_importance_weights


    # The master splits its time between two tasks.
    #
    # (1) Publish the parameters back to the server,
    #     which triggers a cascade of re-evaluation of
    #     importance weights for every minibatch.
    #
    # (2) Get samples representing training examples
    #     on which you perform training steps, taking into
    #     consideration all the things about the importance weights.
    #
    # Ultimately, the parameters must be shared, but it is
    # wasteful to do it at every training step. We have to find
    # the right balance.
    #
    # Task (1) should also be triggered on the first iteration
    # to initialize the parameters on the server before anything
    # else (that being said, the initial weights for all the batches
    # are 1.0, so things could start with Task (2) since the assistant
    # would start by resampling the indices.

    queue_name = "L_master_train_minibatch_indices_and_info_QUEUE"

    num_minibatches_master_processed_str = rsconn.get("parameters:num_minibatches_master_processed")
    if num_minibatches_master_processed_str is None or len(num_minibatches_master_processed_str) == 0:
        num_minibatches_master_processed = 0.0
        rsconn.set("parameters:num_minibatches_master_processed", num_minibatches_master_processed)
    else:
        num_minibatches_master_processed = float(num_minibatches_master_processed_str)

    print "num_minibatches_master_processed is %f" % num_minibatches_master_processed

    # The main loop runs until the user hits CTLR+C or until
    # the Helios cluster sends the SIGTERM to that process
    # five minutes before the end of training.
    def signal_handler(signal, frame):
        print "SIGTERM received for the first time."
        print "Will break from master main loop."
        print "Will make logger sync to database before terminating."
        print ""
        logging.info("Master received SIGTERM. %f, %s" % (time.time(), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
        signal_handler.remote_redis_logger.log('event', "Master received SIGTERM.")
        signal_handler.remote_redis_logger.close()
        sys.exit(0)
    #
    # I'm forced to use weird function properties because python
    # has stupid scoping rules.
    # BUGFIX : assign the attribute *before* registering the handler, so that
    # a signal arriving between the two statements cannot hit a handler whose
    # `remote_redis_logger` attribute does not exist yet.
    signal_handler.remote_redis_logger = remote_redis_logger
    signal.signal(signal.SIGINT, signal_handler)

    # cache those values to use them for more than one computation
    D_importance_weights_and_more = None
    extra_statistics = None

    remote_redis_logger.log('event', "Before entering service_master main loop.")
    while True:

        for next_action in master_routine:
            assert next_action in [ "sync_params", "refresh_importance_weights", "process_minibatch", # the normal ones
                                    "wait_for_workers_to_update_all_the_importance_weights"]          # the special ones

            if next_action == "sync_params":

                remote_redis_logger.log('event', "sync_params")
                tic = time.time()
                if serialized_parameters_format == "opaque_string":
                    current_parameters_str = model_api.get_serialized_parameters()
                elif serialized_parameters_format == "ndarray_float32_tostring":
                    current_parameters_str = model_api.get_serialized_parameters().tostring(order='C')
                else:
                    logging.info("Fatal error : invalid serialized_parameters_format : %s." % serialized_parameters_format)
                    quit()
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'read_parameters_from_model' : (toc-tic)})

                tic = time.time()
                # Publish the parameters first, the timestamp second; the
                # workers rely on that ordering to detect fresh parameters.
                rsconn.set("parameters:current", current_parameters_str)
                rsconn.set("parameters:current_timestamp", time.time())
                rsconn.set("parameters:num_minibatches_master_processed", num_minibatches_master_processed)
                # potentially not used
                rsconn.set("parameters:current_datestamp", time.strftime("%Y-%m-%d %H:%M:%S"))
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'sync_params_to_database' : (toc-tic)})
                print "The master has updated the parameters."

            elif next_action == "wait_for_workers_to_update_all_the_importance_weights":

                # This next line is to ask for the master to wait until everything has been
                # updated. This can take a minute or so, and it's not a very good approach.
                # However, it's the way to see what would happen if we implemented ISGD exactly
                # without using any stale importance weights.
                #    wait_until_all_measurements_are_updated_by_workers(rsconn, "train", "importance_weight")
                logging.info("Error. Take the time to test wait_for_workers_to_update_all_the_importance_weights if you want to use it. I expect it to work, though.")

            elif next_action == "refresh_importance_weights":

                remote_redis_logger.log('event', "refresh_importance_weights")
                tic = time.time()
                _, D_importance_weights_and_more = get_raw_importance_weights(rsconn)

                # Filter out importance weights that are too stale (by wall
                # clock or by minibatch count) before using them for sampling.
                (_, D_importance_weights_and_more, extra_statistics) = filter_raw_importance_weights(
                                                                            D_importance_weights_and_more,
                                                                            staleness_threshold_seconds=staleness_threshold_seconds,
                                                                            staleness_threshold_num_minibatches_master_processed=staleness_threshold_num_minibatches_master_processed,
                                                                            importance_weight_additive_constant=importance_weight_additive_constant,
                                                                            num_minibatches_master_processed=num_minibatches_master_processed)

                record_importance_weights_statistics(   D_importance_weights_and_more, extra_statistics,
                                            remote_redis_logger=remote_redis_logger, logging=logging,
                                            want_compute_entropy=True)

                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'refresh_importance_weights' : (toc-tic)})
                #print "The master has obtained fresh importance weights."

            elif next_action == "process_minibatch":

                remote_redis_logger.log('event', "process_minibatch")
                #if A_importance_weights is None or nbr_of_usable_importance_weights is None:
                #    # nothing can be done here
                #    remote_redis_logger.log('event', "process_minibatch skipped")
                #    continue
                #else:
                #    remote_redis_logger.log('event', "process_minibatch")

                # Cause importance sampling to be done randomly if the value of
                # `turn_off_importance_sampling` is a floating-point value.
                # Note that another good approach would have been to alternate between
                # one mode and the other.
                if type(turn_off_importance_sampling) == float:
                    assert 0.0 <= turn_off_importance_sampling
                    assert turn_off_importance_sampling <= 1.0
                    if np.random.rand() <= turn_off_importance_sampling:
                        decision_to_turn_off_importance_sampling_this_iteration = True
                    else:
                        decision_to_turn_off_importance_sampling_this_iteration = False
                else:
                    assert type(turn_off_importance_sampling) == bool
                    decision_to_turn_off_importance_sampling_this_iteration = turn_off_importance_sampling

                tic = time.time()
                (intent, mode, A_sampled_indices, A_scaling_factors) = sample_indices_and_scaling_factors(
                        D_importance_weights_and_more=D_importance_weights_and_more,
                        extra_statistics=extra_statistics,
                        nbr_samples=master_minibatch_size,
                        master_usable_importance_weights_threshold_to_ISGD=master_usable_importance_weights_threshold_to_ISGD,
                        want_master_to_do_USGD_when_ISGD_is_not_possible=want_master_to_do_USGD_when_ISGD_is_not_possible,
                        turn_off_importance_sampling=decision_to_turn_off_importance_sampling_this_iteration)
                toc = time.time()
                remote_redis_logger.log('timing_profiler', {'sample_indices_and_scaling_factors' : (toc-tic)})

                if intent == 'wait_and_retry':
                    # BUGFIX : the original passed a single list to log();
                    # every other call site uses log(channel, message).
                    remote_redis_logger.log('event', "Master does not have enough importance weights to do ISGD, and doesn't want to default to USGD. Sleeping for 2 seconds.")
                    time.sleep(2.0)
                    continue

                if intent == 'proceed':
                    remote_redis_logger.log('event', "Master proceeding with round of %s." % (mode,))
                    tic = time.time()

                    #if not np.all(np.isfinite(A_scaling_factors)):
                    #    import pdb; pdb.set_trace()

                    model_api.master_process_minibatch(A_sampled_indices, A_scaling_factors, "train")
                    toc = time.time()
                    remote_redis_logger.log('timing_profiler', {'master_process_minibatch' : (toc-tic), 'mode':mode})
                    logging.info("The master has processed a minibatch using %s. %f, %s" % (mode, time.time(), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
                    num_minibatches_master_processed += 1


    # NOTE : unreachable in practice (the loop only exits through the signal
    # handler or quit()), kept for symmetry with an eventual `break`.
    remote_redis_logger.log('event', "Master exited from main loop")
    logging.info("Master exited from main loop. %f, %s" % (time.time(), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))))
    remote_redis_logger.close()
    quit()
def run(DD_config, D_server_desc):
    """Main entry point for a "service worker" process.

    Connects to the redis-backed parameter server described by
    `D_server_desc`, then loops forever over `worker_routine`, alternating
    between two actions:

      * "sync_params"       : poll the database for a new parameter
                              timestamp and, if it changed, fetch and
                              install the latest parameters into the model.
      * "process_minibatch" : pop a minibatch index set from a segment
                              queue, compute measurements for it, write the
                              measurements (plus timing/staleness
                              bookkeeping) back to the database, and push
                              the minibatch to the back of its queue.

    Arguments:
        DD_config : nested config dict; this function reads the
            'database' and 'model' sections.
        D_server_desc : dict describing the redis server (at least
            'hostname'), consumed by `get_rsconn_with_timeout`.

    Does not return normally: exits via quit() on fatal errors or
    sys.exit(0) from the SIGINT handler.
    """

    # TODO : Get rid of this cheap hack to circumvent my OSX's inability to see itself.
    if D_server_desc['hostname'] in ["szkmbp"]:
        D_server_desc['hostname'] = "localhost"

    rsconn = get_rsconn_with_timeout(
        D_server_desc,
        timeout=DD_config['database']['connection_setup_timeout'],
        wait_for_parameters_to_be_present=True)

    # NOTE(review): this counter is never incremented nor read anywhere in
    # the visible body of this function — possibly dead, confirm before
    # removing.
    num_importance_weight_batches_processed = 0

    L_measurements = DD_config['database']['L_measurements']
    serialized_parameters_format = DD_config['database'][
        'serialized_parameters_format']
    worker_routine = DD_config['model']['worker_routine']
    # Presumably the routine must start with a sync so the worker never
    # processes a minibatch before having fetched parameters at least once.
    if worker_routine[0] != "sync_params":
        print "Error. Your worker_routine should always start with 'sync_params'."
        print worker_routine
        quit()

    remote_redis_logger = integration_distributed_training.server.logger.RedisLogger(
        rsconn, queue_prefix_identifier="service_worker")

    def signal_handler(signal, frame):
        # Graceful shutdown on CTRL+C: close the remote logger before exiting.
        # NOTE(review): the logged message says "SIGTERM" but this handler is
        # registered only for SIGINT below — confirm which signal is intended.
        print "You pressed CTRL+C."
        print "Closing the remote_redis_logger."
        remote_redis_logger.log('event', "Received SIGTERM.")
        remote_redis_logger.close()
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)

    model_api = ModelAPI(DD_config['model'])
    # This `record_machine_info` has to be called after the component that
    # makes use of theano if we hope to properly record the theano.config.
    integration_distributed_training.server.logger.record_machine_info(
        remote_redis_logger)

    #D_segment_priorities = {'train' : 50, 'valid' : 1, 'test' : 1}
    # Relative sampling weights for the segments, normalized and turned into
    # a cumulative distribution for inverse-CDF sampling in sample_segment().
    segment_priorities_p = np.array([20, 1, 1], dtype=np.float32)
    segment_priorities_p /= segment_priorities_p.sum()
    segment_priorities_p = segment_priorities_p.cumsum()
    segment_priorities_v = ['train', 'valid', 'test']

    def sample_segment():
        # Draw 'train'/'valid'/'test' with relative weights 20:1:1 by
        # locating a uniform sample within the cumulative distribution.
        r = np.random.rand()
        for (i, e) in enumerate(segment_priorities_p):
            if r <= e:
                return segment_priorities_v[i]

    # The worker has to watch two things.
    #
    # (1) Have the parameters been updated on the server ?
    #     Check out the timestamp to determine if they have been updated.
    #     (Because of the assumption that the master updates the parameters
    #     and *then* the timestamp.)
    #     If they have been updated, we want to fetch a copy a convert it
    #     to a numpy array.
    #
    # (2) Process left-most entries from L_workers_%s_minibatch_indices_QUEUE
    #     according to their priorities given by `segment_priorities_p`.
    #
    # This will be useful for the workers when they want to stamp their importance weights
    # with a timestamp that reflects the last time that they got a fresh set of parameters.
    current_parameters = None
    parameters_current_timestamp_str = ""
    parameters_num_minibatches_master_processed_str = ""

    remote_redis_logger.log('event',
                            "Before entering service_worker main loop.")
    while True:

        for next_action in worker_routine:
            assert next_action in ["sync_params", "process_minibatch"]

            if next_action == "sync_params":

                remote_redis_logger.log('event', "sync_params")

                # Cheap poll: only fetch the (potentially large) parameter
                # blob if the timestamp changed since our last sync.
                new_parameters_current_timestamp_str = rsconn.get(
                    "parameters:current_timestamp")

                # Note that `rsconn.get("parameters:num_minibatches_master_processed")` is reflected
                # right back to the database without us even parsing it at all. It's meant to be an alternate
                # way to specify the staleness.

                if parameters_current_timestamp_str != new_parameters_current_timestamp_str:
                    tic = time.time()
                    current_parameters_str = rsconn.get("parameters:current")
                    toc = time.time()
                    remote_redis_logger.log(
                        'timing_profiler',
                        {'sync_params_from_database': (toc - tic)})

                    # NOTE(review): if the key were missing entirely,
                    # rsconn.get would presumably return None and len(None)
                    # would raise TypeError before reaching this guard —
                    # confirm the key is always initialized on the server.
                    if len(current_parameters_str) == 0:
                        print "Error. No parameters found in the server."
                        print "We could recover from this error by just ignoring it, and getting the parameters the next time around."
                        quit()

                    if serialized_parameters_format == "opaque_string":
                        parameters_current_timestamp_str = new_parameters_current_timestamp_str
                        parameters_num_minibatches_master_processed_str = rsconn.get(
                            "parameters:num_minibatches_master_processed")
                        tic = time.time()
                        model_api.set_serialized_parameters(
                            current_parameters_str)
                        toc = time.time()
                        remote_redis_logger.log('timing_profiler', {
                            'model_api.set_serialized_parameters': (toc - tic)
                        })
                        print "The worker has received new parameters. This took %f seconds." % (
                            toc - tic, )
                        # Done syncing; move on to the next action in worker_routine.
                        continue
                    elif serialized_parameters_format == "ndarray_float32_tostring":
                        # Parameters were serialized with ndarray.tostring();
                        # decode them back into a flat float32 array.
                        current_parameters = np.fromstring(
                            current_parameters_str, dtype=np.float32)
                        parameters_current_timestamp_str = new_parameters_current_timestamp_str
                        parameters_num_minibatches_master_processed_str = rsconn.get(
                            "parameters:num_minibatches_master_processed")
                        tic = time.time()
                        model_api.set_serialized_parameters(current_parameters)
                        toc = time.time()
                        remote_redis_logger.log('timing_profiler', {
                            'model_api.set_serialized_parameters': (toc - tic)
                        })
                        print "The worker has received new parameters. This took %f seconds." % (
                            toc - tic, )
                        # Done syncing; move on to the next action in worker_routine.
                        continue
                    else:
                        print "Fatal error : invalid serialized_parameters_format : %s." % serialized_parameters_format
                        quit()

            elif next_action == "process_minibatch":

                remote_redis_logger.log('event', "process_minibatch")
                tic_process_minibatch = time.time()

                # Pick a segment at random (train heavily favored) and try to
                # pop one minibatch worth of indices from its queue.
                segment = sample_segment()
                queue_name = "L_workers_%s_minibatch_indices_QUEUE" % segment
                if rsconn.llen(queue_name) == 0:
                    msg = "The worker has nothing to do.\nThe queue %s is empty." % queue_name
                    print msg
                    remote_redis_logger.log('event', msg)
                    # TODO : Adjust the duration of the sleep.
                    time.sleep(0.2)
                    continue

                current_minibatch_indices_str = rsconn.lpop(queue_name)
                if current_minibatch_indices_str is None or len(
                        current_minibatch_indices_str) == 0:
                    # This is very unexpected, because it implies that we have a queue
                    # that is shorter than the number of workers. It's not illegal, but
                    # just generally not recommended for a setup.
                    msg = "The worker has nothing to do.\nIt is as though queue %s was empty when we tried to pop an element from the left." % queue_name
                    print msg
                    remote_redis_logger.log('event', msg)
                    # TODO : Adjust the duration of the sleep.
                    time.sleep(0.2)
                    continue

                # The queue entry is the raw bytes of an int32 index array;
                # the raw string itself doubles as the hash field key below.
                current_minibatch_indices = np.fromstring(
                    current_minibatch_indices_str, dtype=np.int32)

                # There is a special thing to do with the individual_importance_weight.
                # We want to keep around their previous values.

                tic = time.time()
                tmp_str = rsconn.hget(
                    "H_%s_minibatch_%s" %
                    (segment, "individual_importance_weight"),
                    current_minibatch_indices_str)
                rsconn.hset(
                    "H_%s_minibatch_%s" %
                    (segment, "previous_individual_importance_weight"),
                    current_minibatch_indices_str, tmp_str)
                toc = time.time()
                remote_redis_logger.log(
                    'timing_profiler', {
                        'copied individual_importance_weight to previous_individual_importance_weight':
                        (toc - tic)
                    })

                tic = time.time()
                # This returns a dictionary of numpy arrays.
                DA_measurements = model_api.worker_process_minibatch(
                    current_minibatch_indices, segment, L_measurements)
                toc = time.time()
                remote_redis_logger.log(
                    'timing_profiler',
                    {'worker_process_minibatch': (toc - tic)})

                tic_send_measurements_to_database = time.time()
                # Update the measurements. Update the timestamps.
                # None of the measurements should be missing.
                for measurement in L_measurements:

                    # Validate each measurement array (must be finite
                    # float32) before writing it back to the database.
                    A_values = DA_measurements[measurement]
                    assert type(
                        A_values
                    ) == np.ndarray, "Your `worker_process_minibatch` function is supposed to return an array of np.float32 as measurements (%s), but now those values are not even numpy arrays. They are %s instead." % (
                        measurement, type(A_values))
                    if A_values.dtype == np.float64:
                        # This conversion is acceptable.
                        A_values = A_values.astype(np.float32)
                    assert A_values.dtype == np.float32, "Your `worker_process_minibatch` function is supposed to return an array of np.float32 as measurements (%s), but now that array has dtype %s instead." % (
                        measurement, A_values.dtype)

                    number_of_invalid_values = np.logical_not(
                        np.isfinite(A_values)).sum()
                    if 0 < number_of_invalid_values:
                        msg = "FATAL ERROR. You have %d invalid values returned for %s." % (
                            number_of_invalid_values, measurement)
                        print msg
                        print A_values
                        remote_redis_logger.log('event', msg)
                        quit()
                        #print "Starting debugger."
                        #import pdb; pdb.set_trace()

                    # Store the raw float32 bytes, keyed by the same raw
                    # indices string that identifies the minibatch.
                    rsconn.hset("H_%s_minibatch_%s" % (segment, measurement),
                                current_minibatch_indices_str,
                                A_values.tostring(order='C'))

                    previous_update_timestamp_str = rsconn.hget(
                        "H_%s_minibatch_%s_measurement_last_update_timestamp" %
                        (segment, measurement), current_minibatch_indices_str)
                    if previous_update_timestamp_str is None or len(
                            previous_update_timestamp_str) == 0:
                        print "The measurements are supposed to be initialized when starting the database."
                        print "They are supposed to have a timestamp set at that time."
                        print "This is not a serious error from which we could not recover, but it signals that there is a bug, so let's quit() here."
                        quit()
                        #previous_update_timestamp = 0.0
                    else:
                        previous_update_timestamp = float(
                            previous_update_timestamp_str)

                    print "(%s, %s) timestamp delta between updates to that measurement : %f" % (
                        segment,
                        measurement,
                        time.time() - previous_update_timestamp,
                    )

                    current_update_timestamp = time.time()
                    rsconn.hset(
                        "H_%s_minibatch_%s_measurement_last_update_timestamp" %
                        (segment, measurement), current_minibatch_indices_str,
                        current_update_timestamp)

                    # This string is reflected intact to the database. It does not get parsed. The usefulness of this comes
                    # from the fact that the master can then how many "minibatches ago" the parameters came.
                    # This is basically the same thing as the "delay_between_measurement_update_and_parameter_update"
                    # below, but it's an absolute value instead of being a difference.
                    rsconn.hset(
                        "H_%s_minibatch_%s_measurement_num_minibatches_master_processed"
                        % (segment, measurement),
                        current_minibatch_indices_str,
                        parameters_num_minibatches_master_processed_str)

                    # NOTE(review): this float() assumes a successful
                    # "sync_params" already populated
                    # parameters_current_timestamp_str (initialized to "");
                    # guaranteed only because worker_routine must start with
                    # "sync_params" (checked at the top of this function).
                    delay_between_measurement_update = current_update_timestamp - previous_update_timestamp
                    delay_between_measurement_update_and_parameter_update = current_update_timestamp - float(
                        parameters_current_timestamp_str)

                    rsconn.hset(
                        "H_%s_minibatch_%s_delay_between_measurement_update" %
                        (segment, measurement), current_minibatch_indices_str,
                        delay_between_measurement_update)
                    rsconn.hset(
                        "H_%s_minibatch_%s_delay_between_measurement_update_and_parameter_update"
                        % (segment, measurement),
                        current_minibatch_indices_str,
                        delay_between_measurement_update_and_parameter_update)

                    #print "delay_between_measurement_update : %f seconds" % delay_between_measurement_update
                    #print "delay_between_measurement_update_and_parameter_update : %f seconds" % delay_between_measurement_update_and_parameter_update

                    # Be careful. If you re-indent the next block deeper,
                    # you'll mess up everything with the re-queuing of the minibatches.

                toc_send_measurements_to_database = time.time()
                remote_redis_logger.log(
                    'timing_profiler', {
                        'send_measurements_to_database':
                        (toc_send_measurements_to_database -
                         tic_send_measurements_to_database)
                    })
                # We could log this for every measurement, but we'll just log it for one of them.
                # Otherwise, this is multiplying the messaging without real need.
                # We'll use the values for the last measurement, which outlasts the for loop above
                # due to shitty python scoping.
                remote_redis_logger.log(
                    'delay', {
                        'delay_between_measurement_update':
                        delay_between_measurement_update,
                        'delay_between_measurement_update_and_parameter_update':
                        delay_between_measurement_update_and_parameter_update
                    })

                # Push back that minibatch to the right of the queue.
                # It will eventually find its way back to some worker,
                # but we will cover all the other ones before that happens.
                rsconn.rpush(queue_name, current_minibatch_indices_str)
                toc_process_minibatch = time.time()
                remote_redis_logger.log(
                    'timing_profiler', {
                        'process_minibatch':
                        (toc_process_minibatch - tic_process_minibatch)
                    })
                msg = "Processed one minibatch from %s. Pushed back to back of the line. Total time taken is %f seconds." % (
                    segment, toc_process_minibatch - tic_process_minibatch)
                print msg
                remote_redis_logger.log('event', msg)