Example #1
0
def _validate(graph_manager, task_parameters, transitions, s3_bucket,
              s3_prefix, aws_region):
    """Test the last (and, when available, the best) checkpoint.

    Waits for a checkpoint to land locally, runs model selection for the
    last checkpoint, and emulates one environment step on the trainer for
    each checkpoint that is tested.

    Args:
        graph_manager: RL graph manager driving checkpoint restore and
            trainer emulation.
        task_parameters: task parameters; ``checkpoint_restore_path`` is
            used as the local checkpoint directory.
        transitions: transitions passed through to
            ``emulate_act_on_trainer``.
        s3_bucket: S3 bucket holding the checkpoints.
        s3_prefix: S3 key prefix for the checkpoints.
        aws_region: AWS region of the bucket.
    """
    checkpoint_dir = task_parameters.checkpoint_restore_path
    wait_for_checkpoint(checkpoint_dir, graph_manager.data_store)

    if utils.do_model_selection(s3_bucket=s3_bucket,
                                s3_prefix=s3_prefix,
                                region=aws_region,
                                checkpoint_type=LAST_CHECKPOINT):
        # Fix: the two log statements previously had get_best_checkpoint /
        # get_last_checkpoint swapped relative to their messages.
        logger.info(
            "Test Last Checkpoint: %s",
            utils.get_last_checkpoint(s3_bucket, s3_prefix, aws_region))
        graph_manager.create_graph(task_parameters)
        graph_manager.phase = RunPhase.TEST
        graph_manager.emulate_act_on_trainer(EnvironmentSteps(1),
                                             transitions=transitions)
        # A best checkpoint might not exist; only restore and test it when
        # model selection for BEST_CHECKPOINT succeeds (previously the
        # return value was ignored and the restore ran unconditionally).
        if utils.do_model_selection(s3_bucket=s3_bucket,
                                    s3_prefix=s3_prefix,
                                    region=aws_region,
                                    checkpoint_type=BEST_CHECKPOINT):
            logger.info(
                "Test Best Checkpoint: %s",
                utils.get_best_checkpoint(s3_bucket, s3_prefix, aws_region))
            graph_manager.data_store.load_from_store()
            graph_manager.restore_checkpoint()
            graph_manager.emulate_act_on_trainer(EnvironmentSteps(1),
                                                 transitions=transitions)
    else:
        logger.info("Test Last Checkpoint")
        graph_manager.create_graph(task_parameters)
        graph_manager.phase = RunPhase.TEST
        graph_manager.emulate_act_on_trainer(EnvironmentSteps(1),
                                             transitions=transitions)
Example #2
0
def _validate(graph_manager, task_parameters, transitions, s3_bucket,
              s3_prefix, aws_region):
    """Validate the last checkpoint and, if one exists, the best checkpoint.

    Waits for a checkpoint to land locally, runs model selection for the
    last checkpoint, and emulates one environment step on the trainer for
    each checkpoint that is validated. Progress is reported via
    ``screen.log_title``.

    Args:
        graph_manager: RL graph manager driving checkpoint restore and
            trainer emulation.
        task_parameters: task parameters; ``checkpoint_restore_path`` is
            used as the local checkpoint directory.
        transitions: transitions passed through to
            ``emulate_act_on_trainer``.
        s3_bucket: S3 bucket holding the checkpoints.
        s3_prefix: S3 key prefix for the checkpoints.
        aws_region: AWS region of the bucket.
    """
    checkpoint_dir = task_parameters.checkpoint_restore_path
    # Block until a checkpoint is available in the local store.
    wait_for_checkpoint(checkpoint_dir, graph_manager.data_store)

    if utils.do_model_selection(s3_bucket=s3_bucket,
                                s3_prefix=s3_prefix,
                                region=aws_region,
                                checkpoint_type=LAST_CHECKPOINT):
        screen.log_title(" Validating Last Checkpoint: {}".format(
            utils.get_last_checkpoint(s3_bucket, s3_prefix, aws_region)))
        # Build the graph once; it is reused for the best-checkpoint pass.
        graph_manager.create_graph(task_parameters)
        graph_manager.phase = RunPhase.TEST
        screen.log_title(" Start emulate_act_on_trainer on Last Checkpoint")
        graph_manager.emulate_act_on_trainer(EnvironmentSteps(1),
                                             transitions=transitions)
        screen.log_title(
            " emulate_act_on_trainer on Last Checkpoint completed!")
        # Best checkpoint might not exist.
        if utils.do_model_selection(s3_bucket=s3_bucket,
                                    s3_prefix=s3_prefix,
                                    region=aws_region,
                                    checkpoint_type=BEST_CHECKPOINT):
            screen.log_title(" Validating Best Checkpoint: {}".format(
                utils.get_best_checkpoint(s3_bucket, s3_prefix, aws_region)))
            # Re-sync from the store and restore onto the existing graph.
            graph_manager.data_store.load_from_store()
            graph_manager.restore_checkpoint()
            screen.log_title(
                " Start emulate_act_on_trainer on Best Checkpoint")
            graph_manager.emulate_act_on_trainer(EnvironmentSteps(1),
                                                 transitions=transitions)
            screen.log_title(
                " emulate_act_on_trainer on Best Checkpoint completed!")
        else:
            screen.log_title(" No Best Checkpoint to validate.")

    else:
        # Model selection failed for the last checkpoint; validate whatever
        # checkpoint is currently in the local directory.
        screen.log_title(" Validating Last Checkpoint")
        graph_manager.create_graph(task_parameters)
        graph_manager.phase = RunPhase.TEST
        screen.log_title(" Start emulate_act_on_trainer on Last Checkpoint ")
        graph_manager.emulate_act_on_trainer(EnvironmentSteps(1),
                                             transitions=transitions)
        screen.log_title(
            " Start emulate_act_on_trainer on Last Checkpoint completed!")
    screen.log_title(" Validation completed!")
    def save_to_store(self):
        """Upload the current checkpoint for every agent to its S3 bucket.

        For each agent this:
          1. Deletes any stale lock file, then uploads an empty lock file
             so readers know an upload is in progress.
          2. Uploads every file of the checkpoint named by the local
             checkpoint-state file, then the state file itself.
          3. Uploads the "finished" / "trainer ready" marker files when
             they exist locally, and releases the lock.
          4. Writes and uploads the frozen graph (``model_<num>.pb``) used
             for deployment, and copies the best frozen model into the
             SageMaker output dir.
          5. Prunes checkpoints beyond ``NUM_MODELS_TO_KEEP`` from S3,
             never deleting the best checkpoint, and deletes the matching
             frozen graphs.

        Any S3 client error (or unexpected error) terminates the simapp
        via ``log_and_exit``.
        """
        try:
            s3_client = self._get_client()
            base_checkpoint_dir = self.params.base_checkpoint_dir
            for agent_key, bucket in self.params.buckets.items():
                # remove lock file if it exists
                s3_client.delete_object(Bucket=bucket, Key=self._get_s3_key(SyncFiles.LOCKFILE.value, agent_key))

                # acquire lock
                s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                         Bucket=bucket,
                                         Key=self._get_s3_key(SyncFiles.LOCKFILE.value, agent_key))

                # Single-agent jobs write directly into the base dir;
                # multi-agent jobs get one subdir per agent.
                checkpoint_dir = base_checkpoint_dir if len(self.graph_manager.agents_params) == 1 else \
                    os.path.join(base_checkpoint_dir, agent_key)

                state_file = CheckpointStateFile(os.path.abspath(checkpoint_dir))
                ckpt_state = None
                check_point_key_list = []
                if state_file.exists():
                    ckpt_state = state_file.read()
                    checkpoint_file = None
                    num_files_uploaded = 0
                    start_time = time.time()
                    for root, _, files in os.walk(checkpoint_dir):
                        for filename in files:
                            if filename == CheckpointStateFile.checkpoint_state_filename:
                                # Upload the state file last, after all
                                # checkpoint data files are in place.
                                checkpoint_file = (root, filename)
                                continue
                            if filename.startswith(ckpt_state.name):
                                abs_name = os.path.abspath(os.path.join(root, filename))
                                rel_name = os.path.relpath(abs_name, checkpoint_dir)
                                s3_client.upload_file(Filename=abs_name,
                                                      Bucket=bucket,
                                                      Key=self._get_s3_key(rel_name, agent_key),
                                                      Config=boto3.s3.transfer.TransferConfig(multipart_threshold=1))
                                check_point_key_list.append(self._get_s3_key(rel_name, agent_key))
                                num_files_uploaded += 1
                    time_taken = time.time() - start_time
                    LOG.info("Uploaded %s files for checkpoint %s in %.2f seconds", num_files_uploaded, ckpt_state.num, time_taken)
                    if check_point_key_list:
                        # Queue this checkpoint's keys for later pruning.
                        self.delete_queues[agent_key].put(check_point_key_list)

                    # Guard against the state file disappearing between the
                    # exists() check and the walk (previously this would
                    # raise TypeError on checkpoint_file[0]).
                    if checkpoint_file:
                        abs_name = os.path.abspath(os.path.join(checkpoint_file[0], checkpoint_file[1]))
                        rel_name = os.path.relpath(abs_name, checkpoint_dir)
                        s3_client.upload_file(Filename=abs_name,
                                              Bucket=bucket,
                                              Key=self._get_s3_key(rel_name, agent_key))

                # upload Finished if present
                if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.FINISHED.value)):
                    s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                             Bucket=bucket,
                                             Key=self._get_s3_key(SyncFiles.FINISHED.value, agent_key))

                # upload Ready if present
                if os.path.exists(os.path.join(checkpoint_dir, SyncFiles.TRAINER_READY.value)):
                    s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                                             Bucket=bucket,
                                             Key=self._get_s3_key(SyncFiles.TRAINER_READY.value, agent_key))

                # release lock
                s3_client.delete_object(Bucket=bucket,
                                        Key=self._get_s3_key(SyncFiles.LOCKFILE.value, agent_key))

                # Upload the frozen graph which is used for deployment
                if self.graph_manager:
                    # checkpoint state is always present for the checkpoint dir passed.
                    # We make same assumption while we get the best checkpoint in s3_metrics
                    checkpoint_num = ckpt_state.num
                    self.write_frozen_graph(self.graph_manager, agent_key, checkpoint_num)
                    frozen_name = "model_{}.pb".format(checkpoint_num)
                    frozen_graph_fpath = os.path.join(SM_MODEL_PB_TEMP_FOLDER, agent_key,
                                                      frozen_name)
                    frozen_graph_s3_name = frozen_name if len(self.graph_manager.agents_params) == 1 \
                        else os.path.join(agent_key, frozen_name)
                    # upload the model_<ID>.pb to S3.
                    s3_client.upload_file(Filename=frozen_graph_fpath,
                                          Bucket=bucket,
                                          Key=self._get_s3_key(frozen_graph_s3_name, agent_key))
                    LOG.info("saved intermediate frozen graph: %s", self._get_s3_key(frozen_graph_s3_name, agent_key))

                    # Copy the best checkpoint to the SM_MODEL_OUTPUT_DIR
                    copy_best_frozen_model_to_sm_output_dir(bucket,
                                                            self.params.s3_folders[agent_key],
                                                            self.params.aws_region,
                                                            os.path.join(SM_MODEL_PB_TEMP_FOLDER, agent_key),
                                                            os.path.join(SM_MODEL_OUTPUT_DIR, agent_key),
                                                            self.params.s3_endpoint_url)

                # Clean up old checkpoints
                if ckpt_state and self.delete_queues[agent_key].qsize() > NUM_MODELS_TO_KEEP:
                    best_checkpoint = get_best_checkpoint(bucket,
                                                          self.params.s3_folders[agent_key],
                                                          self.params.aws_region,
                                                          self.params.s3_endpoint_url)
                    while self.delete_queues[agent_key].qsize() > NUM_MODELS_TO_KEEP:
                        key_list = self.delete_queues[agent_key].get()
                        if best_checkpoint and all(list(map(lambda file_name: best_checkpoint in file_name,
                                                            [os.path.split(file)[-1] for file in key_list]))):
                            # Never delete the best checkpoint — put it back.
                            self.delete_queues[agent_key].put(key_list)
                        else:
                            delete_iteration_ids = set()
                            for key in key_list:
                                s3_client.delete_object(Bucket=bucket, Key=key)
                                # Get the name of the file in the checkpoint directory that has to be deleted
                                # and extract the iteration id out of the name
                                file_in_checkpoint_dir = os.path.split(key)[-1]
                                # Fix: str.split always returns >= 1 element,
                                # so the old "> 0" test was always true and
                                # filenames without "_Step" produced bogus
                                # iteration ids; "> 1" means "_Step" is present.
                                if len(file_in_checkpoint_dir.split("_Step")) > 1:
                                    delete_iteration_ids.add(file_in_checkpoint_dir.split("_Step")[0])
                            LOG.info("Deleting the frozen models in s3 for the iterations: %s",
                                     delete_iteration_ids)
                            # Delete the model_{}.pb files from the s3 bucket for the previous iterations
                            for iteration_id in list(delete_iteration_ids):
                                frozen_name = "model_{}.pb".format(iteration_id)
                                frozen_graph_s3_name = frozen_name if len(self.graph_manager.agents_params) == 1 \
                                    else os.path.join(agent_key, frozen_name)
                                s3_client.delete_object(Bucket=bucket,
                                                        Key=self._get_s3_key(frozen_graph_s3_name, agent_key))
        except botocore.exceptions.ClientError:
            log_and_exit("Unable to upload checkpoint",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        except Exception:
            log_and_exit("Unable to upload checkpoint",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)