def query_and_train(
    input_table_spec: TableSpec,
    model: ModelManager__Union,
    num_epochs: int,
    use_gpu: bool,
    *,
    setup_data: Optional[Dict[str, bytes]] = None,
    saved_setup_data: Optional[Dict[str, bytes]] = None,
    normalization_data_map: Optional[Dict[str, NormalizationData]] = None,
    reward_options: Optional[RewardOptions] = None,
    reader_options: Optional[ReaderOptions] = None,
    resource_options: Optional[ResourceOptions] = None,
    warmstart_path: Optional[str] = None,
    validator: Optional[ModelValidator__Union] = None,
    publisher: Optional[ModelPublisher__Union] = None,
    named_model_ids: Optional[ModuleNameToEntityId] = None,
    recurring_period: Optional[RecurringPeriod] = None,
) -> RLTrainingOutput:
    child_workflow_id = get_workflow_id()
    if named_model_ids is None:
        named_model_ids = get_new_named_entity_ids(model.value.serving_module_names())

    logger.info("Starting query")

    reward_options = reward_options or RewardOptions()
    reader_options = reader_options or ReaderOptions()
    resource_options = resource_options or ResourceOptions()
    manager = model.value

    if saved_setup_data is not None:

        def _maybe_get_bytes(v) -> bytes:
            if isinstance(v, bytes):
                return v
            # HACK: FBLearner sometimes packs bytes into a Blob
            return v.data

        saved_setup_data = {k: _maybe_get_bytes(v) for k, v in saved_setup_data.items()}

    if setup_data is None:
        data_module = manager.get_data_module(
            input_table_spec=input_table_spec,
            reward_options=reward_options,
            reader_options=reader_options,
            saved_setup_data=saved_setup_data,
        )
        if data_module is not None:
            setup_data = data_module.prepare_data()
            # Throw away existing normalization data map
            normalization_data_map = None

    if sum([int(setup_data is not None), int(normalization_data_map is not None)]) != 1:
        raise ValueError("setup_data and normalization_data_map are mutually exclusive")

    train_dataset = None
    eval_dataset = None
    if normalization_data_map is not None:
        calc_cpe_in_training = manager.should_generate_eval_dataset
        sample_range_output = get_sample_range(input_table_spec, calc_cpe_in_training)
        train_dataset = manager.query_data(
            input_table_spec=input_table_spec,
            sample_range=sample_range_output.train_sample_range,
            reward_options=reward_options,
        )
        eval_dataset = None
        if calc_cpe_in_training:
            eval_dataset = manager.query_data(
                input_table_spec=input_table_spec,
                sample_range=sample_range_output.eval_sample_range,
                reward_options=reward_options,
            )

    logger.info("Starting training")

    results = manager.train_workflow(
        train_dataset,
        eval_dataset,
        num_epochs=num_epochs,
        use_gpu=use_gpu,
        setup_data=setup_data,
        normalization_data_map=normalization_data_map,
        named_model_ids=named_model_ids,
        child_workflow_id=child_workflow_id,
        reward_options=reward_options,
        reader_options=reader_options,
        resource_options=resource_options,
        warmstart_path=warmstart_path,
    )

    if validator is not None:
        results = run_validator(validator, results)

    if publisher is not None:
        results = run_publisher(
            publisher,
            model,
            results,
            named_model_ids,
            child_workflow_id,
            recurring_period,
        )

    return results
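# A hedged usage sketch for the signature above: the caller is assumed to have
# already built `input_table_spec` (TableSpec) and `model` (ModelManager__Union)
# from its own configuration code; the epoch count and option values below are
# hypothetical. Exactly one of `setup_data` / `normalization_data_map` may be
# passed, per the check inside the function.
#
#     results = query_and_train(
#         input_table_spec,
#         model,
#         num_epochs=10,
#         use_gpu=True,
#         normalization_data_map=normalization_data_map,
#         reward_options=RewardOptions(),
#         validator=validator,    # optional post-training validation
#         publisher=publisher,    # optional model publishing
#     )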
def query_and_train(
    input_table_spec: TableSpec,
    model: ModelManager__Union,
    normalization_data_map: Dict[str, NormalizationData],
    num_epochs: int,
    use_gpu: bool,
    reward_options: Optional[RewardOptions] = None,
    reader_options: Optional[ReaderOptions] = None,
    resource_options: Optional[ResourceOptions] = None,
    warmstart_path: Optional[str] = None,
    validator: Optional[ModelValidator__Union] = None,
    publisher: Optional[ModelPublisher__Union] = None,
    parent_workflow_id: Optional[int] = None,
    recurring_period: Optional[RecurringPeriod] = None,
) -> RLTrainingOutput:
    child_workflow_id = get_workflow_id()
    if parent_workflow_id is None:
        parent_workflow_id = child_workflow_id

    logger.info("Starting query")

    reward_options = reward_options or RewardOptions()
    reader_options = reader_options or ReaderOptions()
    resource_options = resource_options or ResourceOptions()
    manager = model.value

    calc_cpe_in_training = manager.should_generate_eval_dataset
    sample_range_output = get_sample_range(input_table_spec, calc_cpe_in_training)
    train_dataset = manager.query_data(
        input_table_spec=input_table_spec,
        sample_range=sample_range_output.train_sample_range,
        reward_options=reward_options,
    )
    eval_dataset = None
    if calc_cpe_in_training:
        eval_dataset = manager.query_data(
            input_table_spec=input_table_spec,
            sample_range=sample_range_output.eval_sample_range,
            reward_options=reward_options,
        )

    logger.info("Starting training")

    results = manager.train_workflow(
        train_dataset,
        eval_dataset,
        normalization_data_map,
        num_epochs,
        use_gpu,
        parent_workflow_id=parent_workflow_id,
        child_workflow_id=child_workflow_id,
        reward_options=reward_options,
        reader_options=reader_options,
        resource_options=resource_options,
        warmstart_path=warmstart_path,
    )

    if validator is not None:
        results = run_validator(validator, results)

    if publisher is not None:
        results = run_publisher(
            publisher,
            model,
            results,
            parent_workflow_id,
            child_workflow_id,
            recurring_period,
        )

    return results
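# A similar hedged sketch for this older signature, where `normalization_data_map`
# is a required positional argument and `parent_workflow_id` threads the parent of
# a recurring run through to the publisher (it falls back to `child_workflow_id`
# when None); all concrete values here are hypothetical.
#
#     results = query_and_train(
#         input_table_spec,
#         model,
#         normalization_data_map,
#         num_epochs=5,
#         use_gpu=False,
#         parent_workflow_id=parent_workflow_id,
#     )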