def handle_distributed_coach_tasks(graph_manager, args, task_parameters):
    """Configure and launch the appropriate distributed-Coach worker.

    Decodes the JSON-encoded memory-backend and data-store settings from
    *args*, registers them on *graph_manager*, then starts either the
    trainer or the rollout worker according to
    ``args.distributed_coach_run_type``.

    NOTE(review): indentation was lost in this chunk; structure was
    reconstructed from statement order — confirm against upstream.
    """
    checkpoint_dir = "/checkpoint"

    # Wire up the shared memory backend, tagging the config with this
    # process' role so the backend knows whether it trains or rolls out.
    if args.memory_backend_params:
        backend_cfg = json.loads(args.memory_backend_params)
        backend_cfg['run_type'] = str(args.distributed_coach_run_type)
        graph_manager.agent_params.memory.register_var(
            'memory_backend_params', construct_memory_params(backend_cfg))

    # Optional data store used to move checkpoints between workers.
    data_store = None
    if args.data_store_params:
        store_cfg = construct_data_store_params(json.loads(args.data_store_params))
        store_cfg.expt_dir = args.experiment_path
        store_cfg.checkpoint_dir = checkpoint_dir
        graph_manager.data_store_params = store_cfg
        data_store = get_data_store(store_cfg)

    run_type = args.distributed_coach_run_type
    if run_type == RunType.TRAINER:
        # The trainer saves checkpoints inside the container path.
        task_parameters.checkpoint_save_dir = checkpoint_dir
        training_worker(graph_manager=graph_manager,
                        data_store=data_store,
                        task_parameters=task_parameters,
                        is_multi_node_test=args.is_multi_node_test)

    if run_type == RunType.ROLLOUT_WORKER:
        rollout_worker(graph_manager=graph_manager,
                       data_store=data_store,
                       num_workers=args.num_workers,
                       task_parameters=task_parameters)
def handle_distributed_coach_tasks(graph_manager, args):
    """Set up distributed-Coach plumbing and run the trainer or rollout worker.

    Decodes the JSON-encoded memory-backend and data-store settings from
    *args*, registers them on *graph_manager*, then dispatches on
    ``args.distributed_coach_run_type``: the trainer is started directly,
    while the rollout worker first waits for an initial checkpoint.

    NOTE(review): this chunk contains two definitions of
    ``handle_distributed_coach_tasks`` with different signatures; at import
    time the later one wins. Looks like a merge/refactor artifact — confirm
    which version is intended.
    NOTE(review): indentation was lost in this chunk; structure was
    reconstructed from statement order — confirm against upstream.
    """
    checkpoint_dir = "/checkpoint"

    # Wire up the shared memory backend, tagging the config with this
    # process' role.
    if args.memory_backend_params:
        backend_cfg = json.loads(args.memory_backend_params)
        backend_cfg['run_type'] = str(args.distributed_coach_run_type)
        graph_manager.agent_params.memory.register_var(
            'memory_backend_params', construct_memory_params(backend_cfg))

    # Optional data store; checkpoints are exchanged through it.
    store_cfg = None
    if args.data_store_params:
        store_cfg = construct_data_store_params(json.loads(args.data_store_params))
        store_cfg.checkpoint_dir = checkpoint_dir
        graph_manager.data_store_params = store_cfg

    if args.distributed_coach_run_type == RunType.TRAINER:
        training_worker(
            graph_manager=graph_manager,
            checkpoint_dir=checkpoint_dir
        )

    if args.distributed_coach_run_type == RunType.ROLLOUT_WORKER:
        # The rollout worker cannot act before the trainer has published a
        # first checkpoint, so block until one appears.
        data_store = get_data_store(store_cfg) if args.data_store_params else None
        wait_for_checkpoint(checkpoint_dir=checkpoint_dir, data_store=data_store)
        rollout_worker(
            graph_manager=graph_manager,
            checkpoint_dir=checkpoint_dir,
            data_store=data_store,
            num_workers=args.num_workers
        )