Ejemplo n.º 1
0
    def to_proto(self):
        """Serialize this TF Deeplab backend config into a ``BackendConfigMsg``.

        Script locations, training options and URIs held on this object are
        gathered into a plain dict, which ``json_format.ParseDict`` turns into
        a ``BackendConfigMsg.TFDeeplabConfig``.

        Returns:
            A ``BackendConfigMsg`` with ``backend_type=rv.TF_DEEPLAB`` and
            ``tf_deeplab_config`` set to the parsed config.
            ``pretrained_model_uri`` is only populated when
            ``self.pretrained_model_uri`` is truthy.

        Raises:
            json_format.ParseError: raised by ``ParseDict`` if the gathered
                dict does not fit the ``TFDeeplabConfig`` message.
        """
        scripts = self.script_locations
        opts = self.train_options

        config_dict = {
            'train_py': scripts.train_py,
            'eval_py': scripts.eval_py,
            'export_py': scripts.export_py,
            'train_restart_dir': opts.train_restart_dir,
            'sync_interval': opts.sync_interval,
            'do_monitoring': opts.do_monitoring,
            'do_eval': opts.do_eval,
            'replace_model': opts.replace_model,
            'debug': self.debug,
            'training_data_uri': self.training_data_uri,
            'training_output_uri': self.training_output_uri,
            'model_uri': self.model_uri,
            'fine_tune_checkpoint_name': self.fine_tune_checkpoint_name,
            'tfdl_config': self.tfdl_config,
        }
        deeplab_conf = json_format.ParseDict(
            config_dict, BackendConfigMsg.TFDeeplabConfig())

        # Only pass the pretrained model URI when one was given, so a falsy
        # value leaves the field unset on the resulting message.
        extra_fields = {}
        if self.pretrained_model_uri:
            extra_fields['pretrained_model_uri'] = self.pretrained_model_uri

        return BackendConfigMsg(backend_type=rv.TF_DEEPLAB,
                                tf_deeplab_config=deeplab_conf,
                                **extra_fields)
Ejemplo n.º 2
0
    def to_proto(self):
        """Serialize this TF Object Detection backend config into a proto.

        Training options, script locations and URIs are gathered into a plain
        dict and parsed into a ``BackendConfigMsg.TFObjectDetectionConfig``
        via ``json_format.ParseDict``. The proto keys ``model_main_py`` and
        ``export_py`` are fed from ``script_locations.model_main_uri`` and
        ``script_locations.export_uri`` respectively.

        Returns:
            A ``BackendConfigMsg`` with ``backend_type=rv.TF_OBJECT_DETECTION``
            and ``tf_object_detection_config`` set to the parsed config.
            ``pretrained_model_uri`` is only populated when
            ``self.pretrained_model_uri`` is truthy.

        Raises:
            json_format.ParseError: raised by ``ParseDict`` if the gathered
                dict does not fit the ``TFObjectDetectionConfig`` message.
        """
        opts = self.train_options
        scripts = self.script_locations

        config_dict = {
            'sync_interval': opts.sync_interval,
            'do_monitoring': opts.do_monitoring,
            'replace_model': opts.replace_model,
            'model_main_py': scripts.model_main_uri,
            'export_py': scripts.export_uri,
            'training_data_uri': self.training_data_uri,
            'training_output_uri': self.training_output_uri,
            'model_uri': self.model_uri,
            'debug': self.debug,
            'fine_tune_checkpoint_name': self.fine_tune_checkpoint_name,
            'tfod_config': self.tfod_config,
        }
        od_conf = json_format.ParseDict(
            config_dict, BackendConfigMsg.TFObjectDetectionConfig())

        # Only pass the pretrained model URI when one was given, so a falsy
        # value leaves the field unset on the resulting message.
        extra_fields = {}
        if self.pretrained_model_uri:
            extra_fields['pretrained_model_uri'] = self.pretrained_model_uri

        return BackendConfigMsg(backend_type=rv.TF_OBJECT_DETECTION,
                                tf_object_detection_config=od_conf,
                                **extra_fields)
    def to_proto(self):
        """Serialize this custom backend config into a ``BackendConfigMsg``.

        Training options and URIs are written into a free-form
        ``struct_pb2.Struct`` that becomes the message's ``custom_config``;
        ``backend_type`` is copied from ``self.backend_type``.

        NOTE: ``google.protobuf.Struct`` stores numbers as doubles, so integer
        options such as ``sync_interval`` will read back as floats.

        Returns:
            A ``BackendConfigMsg``. ``pretrained_model_uri`` is only populated
            when ``self.pretrained_model_uri`` is truthy.
        """
        opts = self.train_options
        entries = (
            ('sync_interval', opts.sync_interval),
            ('do_monitoring', opts.do_monitoring),
            ('replace_model', opts.replace_model),
            ('model_uri', self.model_uri),
            ('debug', self.debug),
            ('training_data_uri', self.training_data_uri),
            ('training_output_uri', self.training_output_uri),
        )

        custom = struct_pb2.Struct()
        for key, value in entries:
            custom[key] = value

        # Only pass the pretrained model URI when one was given, so a falsy
        # value leaves the field unset on the resulting message.
        extra_fields = {}
        if self.pretrained_model_uri:
            extra_fields['pretrained_model_uri'] = self.pretrained_model_uri

        return BackendConfigMsg(backend_type=self.backend_type,
                                custom_config=custom,
                                **extra_fields)
Ejemplo n.º 4
0
    def to_proto(self):
        """Serialize this custom backend config into a ``BackendConfigMsg``.

        Each entry of ``self.scenes`` is converted with its own ``to_proto()``
        and then flattened to a dict with ``json_format.MessageToDict``. The
        resulting list is stored under ``'scenes'`` in a free-form
        ``struct_pb2.Struct``, alongside the training hyperparameters and
        output/model URIs. That struct becomes the message's
        ``custom_config``.

        Returns:
            A ``BackendConfigMsg`` with ``backend_type`` copied from
            ``self.backend_type``. ``pretrained_model_uri`` is only populated
            when ``self.pretrained_model_uri`` is truthy.
        """
        struct = struct_pb2.Struct()
        # A Struct holds JSON-like values rather than protobuf messages, so
        # each scene message is converted to a plain dict first.
        struct['scenes'] = [
            json_format.MessageToDict(scene.to_proto())
            for scene in self.scenes
        ]
        struct['batch_size'] = self.batch_size
        struct['epochs'] = self.epochs
        struct['epoch_size'] = self.epoch_size
        struct['epoch_save_rate'] = self.epoch_save_rate
        struct['training_output_uri'] = self.training_output_uri
        struct['model_uri'] = self.model_uri

        msg = BackendConfigMsg(backend_type=self.backend_type,
                               custom_config=struct)

        # Leave pretrained_model_uri unset unless a non-empty value was given.
        if self.pretrained_model_uri:
            msg.MergeFrom(
                BackendConfigMsg(
                    pretrained_model_uri=self.pretrained_model_uri))
        return msg
    def to_proto(self):
        """Serialize this Keras classification backend config into a proto.

        Training options, URIs, the debug flag and ``kc_config`` are gathered
        into a plain dict, which ``json_format.ParseDict`` parses into a
        ``BackendConfigMsg.KerasClassificationConfig``.

        Returns:
            A ``BackendConfigMsg`` with
            ``backend_type=rv.KERAS_CLASSIFICATION`` and
            ``keras_classification_config`` set to the parsed config.
            ``pretrained_model_uri`` is only populated when
            ``self.pretrained_model_uri`` is truthy.

        Raises:
            json_format.ParseError: raised by ``ParseDict`` if the gathered
                dict does not fit the ``KerasClassificationConfig`` message.
        """
        opts = self.train_options

        config_dict = {
            'sync_interval': opts.sync_interval,
            'do_monitoring': opts.do_monitoring,
            'replace_model': opts.replace_model,
            'training_data_uri': self.training_data_uri,
            'training_output_uri': self.training_output_uri,
            'model_uri': self.model_uri,
            'debug': self.debug,
            'kc_config': self.kc_config,
        }
        keras_conf = json_format.ParseDict(
            config_dict, BackendConfigMsg.KerasClassificationConfig())

        # Only pass the pretrained model URI when one was given, so a falsy
        # value leaves the field unset on the resulting message.
        extra_fields = {}
        if self.pretrained_model_uri:
            extra_fields['pretrained_model_uri'] = self.pretrained_model_uri

        return BackendConfigMsg(backend_type=rv.KERAS_CLASSIFICATION,
                                keras_classification_config=keras_conf,
                                **extra_fields)