def make_learner_thread(local_worker, config):
    """Construct the learner thread that will train ``local_worker``.

    When ``config["simple_optimizer"]`` is truthy, a ``LearnerThread`` is
    built. Otherwise a ``MultiGPULearnerThread`` is built. In that mode, if
    ``config["num_multi_gpu_tower_stacks"]`` is smaller than
    ``config["minibatch_buffer_size"]``, a warning is logged and
    ``config["minibatch_buffer_size"]`` is overwritten *in place* with the
    number of tower stacks.

    Args:
        local_worker: Passed as the first positional argument to the
            learner-thread constructor.
        config: Mapping of settings. The keys read depend on which branch
            is taken (see the code below).

    Returns:
        The newly constructed learner-thread object.

    NOTE(review): this file also contains a second ``def
    make_learner_thread`` after this one. The later definition rebinds the
    name at import time, so confirm which version is intended.
    """
    # Simple path: a single LearnerThread, no tower-stack bookkeeping.
    if config["simple_optimizer"]:
        return LearnerThread(
            local_worker,
            minibatch_buffer_size=config["minibatch_buffer_size"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"])

    gpu_count = config["num_gpus"]
    stack_count = config["num_multi_gpu_tower_stacks"]
    logger.info(
        "Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".format(
            gpu_count, stack_count))

    buffer_slots = config["minibatch_buffer_size"]
    if stack_count < buffer_slots:
        # Clamp the buffer size to the stack count. The caller's config
        # mapping is mutated, so later readers of it see the new value.
        logger.warning(
            "In multi-GPU mode you should have at least as many "
            "multi-GPU tower stacks (to load data into on one device) as "
            "you have stack-index slots in the buffer! You have "
            f"configured {stack_count} stacks and a buffer of size "
            f"{buffer_slots}. Setting "
            f"`minibatch_buffer_size={stack_count}`.")
        config["minibatch_buffer_size"] = stack_count

    # NOTE(review): minibatch_buffer_size is not forwarded to
    # MultiGPULearnerThread here. Presumably the thread sizes its buffer
    # from num_multi_gpu_tower_stacks; confirm against its constructor.
    return MultiGPULearnerThread(
        local_worker,
        num_gpus=gpu_count,
        lr=config["lr"],
        train_batch_size=config["train_batch_size"],
        num_multi_gpu_tower_stacks=stack_count,
        num_sgd_iter=config["num_sgd_iter"],
        learner_queue_size=config["learner_queue_size"],
        learner_queue_timeout=config["learner_queue_timeout"])
def make_learner_thread(local_worker, config):
    """Construct the learner thread that will train ``local_worker``.

    When ``config["simple_optimizer"]`` is truthy, a ``LearnerThread`` is
    built. Otherwise a ``MultiGPULearnerThread`` is built, after checking
    that there are at least as many multi-GPU tower stacks as minibatch
    buffer slots.

    Args:
        local_worker: Passed as the first positional argument to the
            learner-thread constructor.
        config: Mapping of settings. The keys read depend on which branch
            is taken (see the code below). It is not modified.

    Returns:
        The newly constructed learner-thread object.

    Raises:
        ValueError: In multi-GPU mode, if
            ``config["num_multi_gpu_tower_stacks"]`` is smaller than
            ``config["minibatch_buffer_size"]``.

    NOTE(review): this file defines ``make_learner_thread`` more than once.
    Only the last definition is bound after import, so confirm which
    version is intended.
    """
    # Simple path: a single LearnerThread, no tower-stack validation.
    if config["simple_optimizer"]:
        return LearnerThread(
            local_worker,
            minibatch_buffer_size=config["minibatch_buffer_size"],
            num_sgd_iter=config["num_sgd_iter"],
            learner_queue_size=config["learner_queue_size"],
            learner_queue_timeout=config["learner_queue_timeout"])

    gpu_count = config["num_gpus"]
    stack_count = config["num_multi_gpu_tower_stacks"]
    logger.info(
        "Enabling multi-GPU mode, {} GPUs, {} parallel tower-stacks".format(
            gpu_count, stack_count))

    buffer_slots = config["minibatch_buffer_size"]
    # Guard clause: reject configs with more buffer slots than tower stacks.
    if stack_count < buffer_slots:
        raise ValueError(
            "In multi-GPU mode you must have at least as many "
            "parallel multi-GPU towers as minibatch buffers: "
            "{} vs {}".format(stack_count, buffer_slots))

    return MultiGPULearnerThread(
        local_worker,
        num_gpus=gpu_count,
        lr=config["lr"],
        train_batch_size=config["train_batch_size"],
        num_multi_gpu_tower_stacks=stack_count,
        minibatch_buffer_size=buffer_slots,
        num_sgd_iter=config["num_sgd_iter"],
        learner_queue_size=config["learner_queue_size"],
        learner_queue_timeout=config["learner_queue_timeout"])