示例#1
0
文件: train.py 项目: XinYao1994/sedna
def main():
    """Train a federated-learning model on the configured datasets.

    Dataset locations are taken from ``BaseConfig`` and hyperparameters from
    ``Context.get_parameters`` (falling back to the defaults given here).

    Returns:
        The value returned by ``FederatedLearning.train``.
    """
    train_url = BaseConfig.train_dataset_url
    test_url = BaseConfig.test_dataset_url

    # Both splits are parsed with the same ``image_process`` callback.
    parsed = {}
    for split, url in (("train", train_url), ("test", test_url)):
        parser = TxtDataParse(data_type=split, func=image_process)
        parser.parse(url)
        parsed[split] = parser

    # Parameters are read in a fixed order; the insertion order of this dict
    # is also the keyword order handed to ``train``.
    hyperparameters = {
        "epochs": int(Context.get_parameters("epochs", 1)),
        "batch_size": int(Context.get_parameters("batch_size", 1)),
    }
    aggregation_algorithm = Context.get_parameters("aggregation_algorithm",
                                                   "FedAvg")
    hyperparameters["learning_rate"] = float(
        Context.get_parameters("learning_rate", 0.001))
    hyperparameters["validation_split"] = float(
        Context.get_parameters("validation_split", 0.2))

    fl_model = FederatedLearning(estimator=Estimator,
                                 aggregation=aggregation_algorithm)
    return fl_model.train(train_data=parsed["train"],
                          valid_data=parsed["test"],
                          **hyperparameters)
示例#2
0
文件: train.py 项目: XinYao1994/sedna
def main():
    """Train a lifelong-learning job on the configured CSV dataset.

    Tasks are defined by the data attributes listed in
    ``DATACONF["ATTRIBUTES"]``; every other lifelong-learning hook is passed
    as ``None``.

    Returns:
        The value returned by ``LifelongLearning.train``.
    """
    dataset_url = BaseConfig.train_dataset_url
    train_data = CSVDataParse(data_type="train", func=feature_process)
    train_data.parse(dataset_url, label=DATACONF["LABEL"])

    # The task-definition parameter is handed over as a JSON string.
    task_definition = {
        "method": "TaskDefinitionByDataAttr",
        "param": json.dumps({"attribute": DATACONF["ATTRIBUTES"]}),
    }
    early_stopping_rounds = int(
        Context.get_parameters("early_stopping_rounds", 100))
    metric_name = Context.get_parameters("metric_name", "mlogloss")

    # Hooks that are explicitly passed as ``None``.
    unset_hooks = dict.fromkeys((
        "task_relationship_discovery",
        "task_mining",
        "task_remodeling",
        "inference_integrate",
        "unseen_task_detect",
    ))
    ll_job = LifelongLearning(estimator=Estimator,
                              task_definition=task_definition,
                              **unset_hooks)
    return ll_job.train(train_data=train_data,
                        metric_name=metric_name,
                        early_stopping_rounds=early_stopping_rounds)
示例#3
0
def run():
    """Run incremental-learning inference on a video stream, forever.

    The stream address and the model input shape come from
    ``Context.get_parameters``. Only one frame out of every ``fps`` frames is
    passed to the model; the others are counted and dropped. The loop never
    returns.
    """
    camera_address = Context.get_parameters('video_url')

    # Hard example mining algorithm, as selected by the configuration.
    hard_example_mining = IncrementalLearning.get_hem_algorithm_from_config(
        threshold_img=0.9)

    # "h,w"-style comma separated string -> tuple of ints.
    shape_text = Context.get_parameters("input_shape")
    input_shape = tuple(int(dim) for dim in shape_text.split(","))

    incremental_instance = IncrementalLearning(
        estimator=Estimator, hard_example_mining=hard_example_mining)

    # The video stream is the inference input.
    camera = cv2.VideoCapture(camera_address)
    fps = 10
    nframe = 0
    while True:
        ret, input_yuv = camera.read()
        if not ret:
            # Reading failed: wait, reopen the stream and try again.
            time.sleep(5)
            camera = cv2.VideoCapture(camera_address)
            continue

        # The frame counter advances for every frame, processed or not.
        frame_index = nframe
        nframe += 1
        if frame_index % fps != 0:
            continue

        # Despite the variable name, the frame is converted as BGR -> RGB.
        img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB)
        if nframe % 1000 == 1:  # logs every 1000 frames
            warnings.warn(f"camera is open, current frame index is {nframe}")
        results, _, is_hard_example = incremental_instance.inference(
            img_rgb, post_process=deal_infer_rsl, input_shape=input_shape)
        output_deal(is_hard_example, results, nframe, img_rgb)
示例#4
0
    def __init__(self,
                 estimator,
                 task_definition=None,
                 task_relationship_discovery=None,
                 task_mining=None,
                 task_remodeling=None,
                 inference_integrate=None,
                 unseen_task_detect=None):
        """Wrap ``estimator`` in a ``MulTaskLearning`` and set up the job.

        Args:
            estimator: Base estimator handed to ``MulTaskLearning``.
            task_definition: Mapping describing the task definition method;
                a falsy value selects ``TaskDefinitionByDataAttr``.
            task_relationship_discovery: Forwarded to ``MulTaskLearning``.
            task_mining: Forwarded to ``MulTaskLearning``.
            task_remodeling: Forwarded to ``MulTaskLearning``.
            inference_integrate: Forwarded to ``MulTaskLearning``.
            unseen_task_detect: Mapping with optional ``"method"`` and
                ``"param"`` keys; a falsy value selects ``TaskAttrFilter``.
        """
        task_definition = task_definition or {
            "method": "TaskDefinitionByDataAttr"}
        unseen_task_detect = unseen_task_detect or {
            "method": "TaskAttrFilter"}

        multi_task_estimator = MulTaskLearning(
            estimator=estimator,
            task_definition=task_definition,
            task_relationship_discovery=task_relationship_discovery,
            task_mining=task_mining,
            task_remodeling=task_remodeling,
            inference_integrate=inference_integrate)

        self.unseen_task_detect = unseen_task_detect.get(
            "method", "TaskAttrFilter")
        # The detector parameters go through the same parser that the
        # multi-task estimator exposes.
        raw_detect_param = unseen_task_detect.get("param", {})
        self.unseen_task_detect_param = multi_task_estimator._parse_param(
            raw_detect_param)

        kb_server_url = Context.get_parameters("KB_SERVER")
        output_url = Context.get_parameters("OUTPUT_URL", "/tmp")
        config = {
            "ll_kb_server": kb_server_url,
            "output_url": output_url,
            # The task index file lives directly under the output location.
            "task_index": FileOps.join_path(
                output_url, KBResourceConstant.KB_INDEX_NAME.value),
        }
        super(LifelongLearning, self).__init__(
            estimator=multi_task_estimator, config=config)
        self.job_kind = K8sResourceKind.LIFELONG_JOB.value
        self.kb_server = KBClient(kbserver=self.config.ll_kb_server)
示例#5
0
def run_server():
    """Create an ``AggregationServer`` from context parameters and start it.

    Reads ``aggregation_algorithm`` (default ``"FedAvg"``), ``exit_round``
    (default 3) and ``participants_count`` (default 1) from ``Context``.
    """
    algorithm = Context.get_parameters("aggregation_algorithm", "FedAvg")
    rounds_before_exit = int(Context.get_parameters("exit_round", 3))
    expected_participants = int(
        Context.get_parameters("participants_count", 1))

    ws_size = 20 * 1024 * 1024
    AggregationServer(aggregation=algorithm,
                      exit_round=rounds_before_exit,
                      ws_size=ws_size,
                      participants_count=expected_participants).start()
示例#6
0
def main():
    """Continuously run lifelong-learning inference over a growing CSV file.

    The file named by the ``infer_dataset_url`` parameter is read line by
    line; when no new line is available the reader sleeps for a second and
    retries from the same offset. Every inferred row is appended, with its
    prediction, to ``<infer_dataset_url>_out.csv``; rows flagged as unseen
    are additionally written to ``unseen_sample.csv`` under the
    ``UTD_SAVED_URL`` location.

    The loop has no exit condition, so this function only ends through an
    exception (including ``KeyboardInterrupt``). All three files are held in
    ``with`` blocks so they are flushed and closed when that happens.
    """
    utd = Context.get_parameters("UTD_NAME", "TaskAttrFilter")
    attribute = json.dumps({"attribute": DATACONF["ATTRIBUTES"]})
    utd_parameters = Context.get_parameters("UTD_PARAMETERS", {})
    ut_saved_url = Context.get_parameters("UTD_SAVED_URL", "/tmp")

    task_mining = {"method": "TaskMiningByDataAttr", "param": attribute}

    unseen_task_detect = {"method": utd, "param": utd_parameters}

    ll_service = LifelongLearning(estimator=Estimator,
                                  task_mining=task_mining,
                                  task_definition=None,
                                  task_relationship_discovery=None,
                                  task_remodeling=None,
                                  inference_integrate=None,
                                  unseen_task_detect=unseen_task_detect)

    infer_dataset_url = Context.get_parameters('infer_dataset_url')
    unseen_sample_url = os.path.join(ut_saved_url, "unseen_sample.csv")
    output_sample_url = f"{infer_dataset_url}_out.csv"

    with open(infer_dataset_url, "r", encoding="utf-8") as file_handle:
        # The first line of the input is the column header.
        header = list(csv.reader([file_handle.readline().strip()]))[0]
        infer_data = CSVDataParse(data_type="test", func=feature_process)
        header_line = "\t".join(header + ['pred']) + "\n"

        with open(unseen_sample_url, "w",
                  encoding="utf-8") as unseen_sample, \
                open(output_sample_url, "w",
                     encoding="utf-8") as output_sample:
            unseen_sample.write(header_line)
            output_sample.write(header_line)

            while True:
                # Remember the offset so an empty read can be retried from
                # the same position once more data has been appended.
                where = file_handle.tell()
                line = file_handle.readline()
                if not line:
                    time.sleep(1)
                    file_handle.seek(where)
                    continue

                rows = list(csv.reader([line.strip()]))[0]
                data = dict(zip(header, rows))
                infer_data.parse(data, label=DATACONF["LABEL"])
                rsl, is_unseen, _target_task = ll_service.inference(
                    infer_data)

                # Only the first prediction is recorded for the row.
                rows.append(list(rsl)[0])

                output = "\t".join(map(str, rows)) + "\n"
                if is_unseen:
                    unseen_sample.write(output)
                output_sample.write(output)
示例#7
0
    def __init__(self, data=None, estimator=None,
                 aggregation=None, transmitter=None) -> None:
        """Fill plato's ``Config`` from the given parts and build a client.

        Args:
            data: Optional dataset description. When it has a ``customized``
                attribute that is truthy, its ``trainset``/``testset`` are
                attached to a plato ``DataSource``. When it has no
                ``customized`` attribute, its ``parameters`` are merged into
                ``Config().data``.
            estimator: Optional object providing ``model`` and
                ``hyperparameters``; the latter are merged into
                ``Config().trainer``.
            aggregation: Optional object whose ``parameters`` mapping becomes
                ``Config().algorithm``; it must contain a ``"type"`` key.
            transmitter: Optional object whose ``parameters`` are merged into
                the server settings.
        """

        # plato is imported lazily, inside the constructor.
        from plato.config import Config
        from plato.datasources import base
        # set parameters
        # Work on plain-dict copies of the current configuration sections.
        server = Config().server._asdict()
        clients = Config().clients._asdict()
        datastore = Config().data._asdict()
        train = Config().trainer._asdict()
        self.datasource = None
        if data is not None:
            if hasattr(data, "customized"):
                # A ``customized`` attribute that is falsy leaves both the
                # datasource and ``Config().data`` untouched.
                if data.customized:
                    self.datasource = base.DataSource()
                    self.datasource.trainset = data.trainset
                    self.datasource.testset = data.testset
            else:
                datastore.update(data.parameters)
                Config().data = Config.namedtuple_from_dict(datastore)

        self.model = None
        if estimator is not None:
            self.model = estimator.model
            train.update(estimator.hyperparameters)
            Config().trainer = Config.namedtuple_from_dict(train)

        if aggregation is not None:
            Config().algorithm = Config.namedtuple_from_dict(
                aggregation.parameters)
            # mistnet switches both client and server type; every other
            # aggregation type enables ``do_test`` on the clients instead.
            if aggregation.parameters["type"] == "mistnet":
                clients["type"] = "mistnet"
                server["type"] = "mistnet"
            else:
                clients["do_test"] = True

        # Both values are stored exactly as ``Context.get_parameters`` returns
        # them (no default, no int conversion for the port).
        server["address"] = Context.get_parameters("AGG_IP")
        server["port"] = Context.get_parameters("AGG_PORT")

        # Transmitter settings are applied last, so they can override the
        # address/port set just above.
        if transmitter is not None:
            server.update(transmitter.parameters)

        # Publish the edited sections back to the configuration.
        Config().server = Config.namedtuple_from_dict(server)
        Config().clients = Config.namedtuple_from_dict(clients)

        # The client is created only after the configuration is complete.
        from plato.clients import registry as client_registry
        self.client = client_registry.get(model=self.model,
                                          datasource=self.datasource)
        self.client.configure()
示例#8
0
    def __init__(self, data=None, estimator=None,
                 aggregation=None, transmitter=None,
                 chooser=None) -> None:
        """Fill plato's ``Config`` from the given parts and build a server.

        Args:
            data: Optional object whose ``parameters`` are merged into
                ``Config().data``.
            estimator: Optional object providing ``model``, ``pretrained``,
                ``saved`` and ``hyperparameters``.
            aggregation: Optional object whose ``parameters`` mapping becomes
                ``Config().algorithm``; it must contain a ``"type"`` key.
            transmitter: Optional object whose ``parameters`` are merged into
                the server settings.
            chooser: Optional object whose ``parameters["per_round"]`` is
                copied into the client settings.
        """
        # plato is imported lazily, inside the constructor.
        from plato.config import Config
        # set parameters
        # Work on plain-dict copies of the current configuration sections.
        server = Config().server._asdict()
        clients = Config().clients._asdict()
        datastore = Config().data._asdict()
        train = Config().trainer._asdict()

        if data is not None:
            datastore.update(data.parameters)
            Config().data = Config.namedtuple_from_dict(datastore)

        self.model = None
        if estimator is not None:
            self.model = estimator.model
            # Model directories are only overridden when the estimator
            # provides them.
            if estimator.pretrained is not None:
                Config().params['pretrained_model_dir'] = estimator.pretrained
            if estimator.saved is not None:
                Config().params['model_dir'] = estimator.saved
            train.update(estimator.hyperparameters)
            Config().trainer = Config.namedtuple_from_dict(train)

        # Bind address/port, with defaults when the parameters are unset.
        server["address"] = Context.get_parameters("AGG_BIND_IP", "0.0.0.0")
        server["port"] = int(Context.get_parameters("AGG_BIND_PORT", 7363))
        # Transmitter settings can override the address/port set just above.
        if transmitter is not None:
            server.update(transmitter.parameters)

        if aggregation is not None:
            Config().algorithm = Config.namedtuple_from_dict(
                aggregation.parameters)
            # mistnet switches both client and server type; every other
            # aggregation type enables ``do_test`` on the clients instead.
            if aggregation.parameters["type"] == "mistnet":
                clients["type"] = "mistnet"
                server["type"] = "mistnet"
            else:
                clients["do_test"] = True

        if chooser is not None:
            clients["per_round"] = chooser.parameters["per_round"]

        LOGGER.info("address %s, port %s", server["address"], server["port"])

        # Publish the edited sections back to the configuration.
        Config().server = Config.namedtuple_from_dict(server)
        Config().clients = Config.namedtuple_from_dict(clients)

        # The server is created only after the configuration is complete.
        from plato.servers import registry as server_registry
        self.server = server_registry.get(model=self.model)
示例#9
0
    def __init__(self,
                 estimator=None,
                 task_definition=None,
                 task_relationship_discovery=None,
                 task_mining=None,
                 task_remodeling=None,
                 inference_integrate=None):
        """Store the multi-task pipeline settings, filling in defaults.

        Every pipeline argument is a mapping; a falsy value is replaced by
        the default method for that stage (``task_mining`` defaults to an
        empty mapping).

        Args:
            estimator: Base model kept as ``self.base_model``.
            task_definition: Task definition settings.
            task_relationship_discovery: Task relation discovery settings.
            task_mining: Task mining settings.
            task_remodeling: Task remodeling settings.
            inference_integrate: Inference integration settings.
        """
        if not task_definition:
            task_definition = {"method": "TaskDefinitionByDataAttr"}
        if not task_relationship_discovery:
            task_relationship_discovery = {
                "method": "DefaultTaskRelationDiscover"}
        if not task_mining:
            task_mining = {}
        if not task_remodeling:
            task_remodeling = {"method": "DefaultTaskRemodeling"}
        if not inference_integrate:
            inference_integrate = {"method": "DefaultInferenceIntegrate"}

        self.task_definition = task_definition
        self.task_relationship_discovery = task_relationship_discovery
        self.task_mining = task_mining
        self.task_remodeling = task_remodeling
        self.inference_integrate = inference_integrate

        # State that is filled in later starts out empty.
        self.models = None
        self.extractor = None
        self.base_model = estimator
        self.task_groups = None
        self.task_index_url = KBResourceConstant.KB_INDEX_NAME.value

        # The context parameter overrides the built-in minimum sample count.
        default_min_samples = KBResourceConstant.MIN_TRAIN_SAMPLE.value
        configured_min_samples = Context.get_parameters(
            "MIN_TRAIN_SAMPLE", default_min_samples)
        self.min_train_sample = int(configured_min_samples)
示例#10
0
def main():
    """Evaluate the incremental-learning estimator on the test dataset.

    The dataset location, the comma separated class names and the comma
    separated input shape are all read through ``Context.get_parameters``.

    Returns:
        The value returned by ``IncrementalLearning.evaluate``.
    """
    # Load the evaluation split.
    dataset_url = Context.get_parameters('test_dataset_url')
    valid_data = TxtDataParse(data_type="test", func=_load_txt_dataset)
    valid_data.parse(dataset_url, use_raw=True)

    # Deployment parameters arrive as comma separated strings.
    raw_class_names = Context.get_parameters("class_names")
    class_names = list(map(str.strip, raw_class_names.split(',')))
    raw_input_shape = Context.get_parameters("input_shape")
    input_shape = tuple(map(int, raw_input_shape.split(',')))

    learner = IncrementalLearning(estimator=Estimator)
    return learner.evaluate(valid_data,
                            class_names=class_names,
                            input_shape=input_shape)
示例#11
0
 def __init__(self) -> None:
     """Set up an estimator with no model and yolov5 hyperparameters.

     ``epochs`` and ``batch_size`` are read from the ``EPOCHS`` and
     ``BATCH_SIZE`` context parameters (defaults 500 and 16).
     """
     self.model = None
     self.pretrained = None
     self.saved = None

     epochs = int(Context.get_parameters("EPOCHS", 500))
     batch_size = int(Context.get_parameters("BATCH_SIZE", 16))
     self.hyperparameters = dict(
         type="yolov5",
         rounds=1,
         target_accuracy=0.99,
         epochs=epochs,
         batch_size=batch_size,
         optimizer="SGD",
         linear_lr=False,
         # The machine learning model
         model_name="yolov5",
         model_config="./yolov5s.yaml",
         train_params="./hyp.scratch.yaml",
     )
示例#12
0
    def __init__(self, estimator, aggregation="FedAvg"):
        """Configure the aggregation endpoint and register this worker.

        Args:
            estimator: Estimator forwarded to the parent constructor.
            aggregation: Name of the aggregation algorithm; it is both the
                last path segment of the aggregation URI and the key used to
                look the algorithm up in ``ClassFactory``.
        """
        config = {
            "protocol": Context.get_parameters("AGG_PROTOCOL", "ws"),
            "agg_ip": Context.get_parameters("AGG_IP", "127.0.0.1"),
            "agg_port": int(Context.get_parameters("AGG_PORT", "7363")),
        }
        # <protocol>://<ip>:<port>/<aggregation>
        config["agg_uri"] = (
            "{protocol}://{agg_ip}:{agg_port}/{aggregation}".format(
                aggregation=aggregation, **config))

        super(FederatedLearning, self).__init__(
            estimator=estimator, config=config)
        self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)

        timeout_seconds = int(Context.get_parameters("CONNECT_TIMEOUT", "300"))
        self.node = None
        self.register(timeout=timeout_seconds)
示例#13
0
文件: base.py 项目: XinYao1994/sedna
 def __init__(self, estimator, callback=None, version="latest"):
     """Prepare the model hot-update poller.

     Args:
         estimator: Estimator stored as ``self.production_estimator``.
         callback: Optional callable stored as ``self.callback``.
         version: Model version label stored as ``self.version``.

     The poller is disabled (``self.run_flag`` is ``False``) when the
     ``MODEL_HOT_UPDATE_CONFIG`` parameter is not set. The polling period
     comes from ``MODEL_POLL_PERIOD_SECONDS`` and falls back to 60 seconds
     when the configured value is below 1.
     """
     self.run_flag = True
     hot_update_conf = Context.get_parameters("MODEL_HOT_UPDATE_CONFIG")
     if not hot_update_conf:
         # The message names the parameter that is actually read above, so
         # operators can search for the right setting.
         LOGGER.error("As `MODEL_HOT_UPDATE_CONFIG` unset a value, skipped")
         self.run_flag = False
     model_check_time = int(
         Context.get_parameters("MODEL_POLL_PERIOD_SECONDS", "60"))
     if model_check_time < 1:
         LOGGER.warning("Catch an abnormal value in "
                        "`MODEL_POLL_PERIOD_SECONDS`, fallback with 60")
         model_check_time = 60
     self.hot_update_conf = hot_update_conf
     self.check_time = model_check_time
     self.production_estimator = estimator
     self.callback = callback
     self.version = version
     self.temp_path = tempfile.gettempdir()
     # The parent initialiser runs last, after all attributes are in place.
     super(ModelLoadingThread, self).__init__()
示例#14
0
 def __init__(self) -> None:
     """Build the model and collect the training hyperparameters.

     ``rounds``, ``epochs``, ``batch_size`` and ``learning_rate`` are read
     from the ``exit_round``, ``epochs``, ``batch_size`` and
     ``learning_rate`` context parameters; the rest are fixed.
     """
     self.model = self.build()
     self.pretrained = None
     self.saved = None

     rounds = int(Context.get_parameters("exit_round", 5))
     epochs = int(Context.get_parameters("epochs", 5))
     batch_size = int(Context.get_parameters("batch_size", 32))
     learning_rate = float(Context.get_parameters("learning_rate", 0.01))
     self.hyperparameters = dict(
         use_tensorflow=True,
         is_compiled=True,
         type="basic",
         rounds=rounds,
         target_accuracy=0.97,
         epochs=epochs,
         batch_size=batch_size,
         optimizer="SGD",
         learning_rate=learning_rate,
         # The machine learning model
         model_name="sdd_model",
         momentum=0.9,
         weight_decay=0.0,
     )
示例#15
0
def image_process(line):
    """Turn one ``"<relative path>,<label>"`` line into an image/label pair.

    The image path is resolved against the directory of the
    ``original_dataset_url`` context parameter, the image is resized to
    128x128 and its array is divided by 255.

    Args:
        line: Text containing exactly one comma separating path and label.

    Returns:
        ``[data, label]`` where both items are numpy arrays; the label is
        ``[0, 1]`` when the integer label is 0 and ``[1, 0]`` otherwise.
    """
    relative_path, raw_label = line.split(',')
    dataset_root = os.path.dirname(
        Context.get_parameters('original_dataset_url'))
    image_path = os.path.join(dataset_root, relative_path)

    image = img_preprocessing.load_img(image_path).resize((128, 128))
    pixels = img_preprocessing.img_to_array(image) / 255.0

    if int(raw_label) == 0:
        one_hot = [0, 1]
    else:
        one_hot = [1, 0]
    return [np.array(pixels), np.array(one_hot)]
示例#16
0
 def __init__(
         self,
         aggregation: str,
         host: str = None,
         http_port: int = None,
         exit_round: int = 1,
         participants_count: int = 1,
         ws_size: int = 10 * 1024 * 1024):
     """Create the aggregation server and its FastAPI application.

     Args:
         aggregation: Aggregation name; used as the server name and as the
             path of both routes.
         host: Bind address; a falsy value falls back to the
             ``AGG_BIND_IP`` parameter, then to ``get_host_ip()``.
         http_port: Bind port; a falsy value falls back to the
             ``AGG_BIND_PORT`` parameter, then to 7363.
         exit_round: Stored as an int that is at least 1.
         participants_count: Stored unchanged.
         ws_size: Forwarded to the parent constructor.
     """
     host = host or Context.get_parameters("AGG_BIND_IP", get_host_ip())
     http_port = http_port or int(
         Context.get_parameters("AGG_BIND_PORT", 7363))
     super(AggregationServer, self).__init__(
         servername=aggregation,
         host=host,
         http_port=http_port,
         ws_size=ws_size)

     self.aggregation = aggregation
     self.participants_count = participants_count
     self.exit_round = max(int(exit_round), 1)

     # The HTTP route and the websocket route share the same path.
     endpoint = f"/{aggregation}"
     http_route = APIRoute(
         endpoint,
         self.client_info,
         response_class=JSONResponse,
     )
     websocket_route = WebSocketRoute(endpoint, BroadcastWs)
     self.app = FastAPI(routes=[http_route, websocket_route])
     self.app.shutdown = False
示例#17
0
def _load_txt_dataset(dataset_url):
    """Resolve ``dataset_url`` against the original dataset's directory.

    The directory is taken from the ``original_dataset_url`` context
    parameter, see https://github.com/kubeedge/sedna/issues/35
    """
    dataset_root = os.path.dirname(
        Context.get_parameters('original_dataset_url'))
    return os.path.join(dataset_root, dataset_url)
示例#18
0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from sedna.algorithms.aggregation import MistNet
from sedna.algorithms.client_choose import SimpleClientChoose
from sedna.common.config import Context
from sedna.core.federated_learning import FederatedLearningV2

# Module-level components, created once at import time.

# Client chooser configured with ``per_round=1``.
simple_chooser = SimpleClientChoose(per_round=1)

# It has been determined that mistnet is required here.
# Both settings are passed through exactly as returned by
# ``Context.get_parameters`` (no default value, no type conversion).
mistnet = MistNet(cut_layer=Context.get_parameters("cut_layer"),
                  epsilon=Context.get_parameters("epsilon"))

# The function `get_transmitter_from_config()` returns an object instance.
# NOTE(review): the name suggests an S3-backed transmitter, but the concrete
# type depends on the configuration -- confirm against FederatedLearningV2.
s3_transmitter = FederatedLearningV2.get_transmitter_from_config()


class Dataset:
    def __init__(self) -> None:
        self.parameters = {
            "datasource": "YOLO",
            "data_params": "./coco128.yaml",
            # Where the dataset is located
            "data_path": "./data/COCO",
            "train_path": "./data/COCO/coco128/images/train2017/",
            "test_path": "./data/COCO/coco128/images/train2017/",
示例#19
0
def main():
    """Register the TensorFlow training flags and run incremental training.

    The dataset location comes from ``BaseConfig``; class names, thresholds,
    input shape, epochs and batch size come from ``Context.get_parameters``.
    Those values are used as flag defaults and are also passed, exactly as
    read, to ``IncrementalLearning.train``.

    Returns:
        The value returned by ``IncrementalLearning.train``.
    """
    # Fixed seed for TensorFlow's random ops.
    tf.set_random_seed(22)

    class_names = Context.get_parameters("class_names")

    # load dataset.
    train_dataset_url = BaseConfig.train_dataset_url
    train_data = TxtDataParse(data_type="train", func=_load_txt_dataset)
    train_data.parse(train_dataset_url, use_raw=True)

    # read parameters from deployment config.
    # None of these has a default here, and none is type-converted.
    obj_threshold = Context.get_parameters("obj_threshold")
    nms_threshold = Context.get_parameters("nms_threshold")
    input_shape = Context.get_parameters("input_shape")
    epochs = Context.get_parameters('epochs')
    batch_size = Context.get_parameters('batch_size')

    # Flag definitions. The flags are registered globally on ``tf.flags``.
    tf.flags.DEFINE_string('train_url',
                           default=BaseConfig.model_url,
                           help='train url for model')
    tf.flags.DEFINE_string('log_url', default=None, help='log url for model')
    tf.flags.DEFINE_string('checkpoint_url',
                           default=None,
                           help='checkpoint url for model')
    # NOTE(review): this help text does not describe a model name; it looks
    # copied from another flag -- confirm the intended wording.
    tf.flags.DEFINE_string('model_name',
                           default=None,
                           help='url for train annotation files')
    tf.flags.DEFINE_list(
        'class_names',
        default=class_names.split(','),
        # 'helmet,helmet-on,person,helmet-off'
        help='label names for the training datasets')
    tf.flags.DEFINE_list('input_shape',
                         default=[int(x) for x in input_shape.split(',')],
                         help='input_shape')  # [352, 640]
    # NOTE(review): ``epochs`` and ``batch_size`` are used as integer flag
    # defaults without an int() conversion -- confirm the type returned by
    # ``Context.get_parameters``.
    tf.flags.DEFINE_integer('max_epochs',
                            default=epochs,
                            help='training number of epochs')
    tf.flags.DEFINE_integer('batch_size',
                            default=batch_size,
                            help='training batch size')
    tf.flags.DEFINE_boolean('load_imagenet_weights',
                            default=False,
                            help='if load imagenet weights or not')
    tf.flags.DEFINE_string('inference_device',
                           default='GPU',
                           help='which type of device is used to do inference,'
                           ' only CPU, GPU or 310D')
    # NOTE(review): same help text as 'load_imagenet_weights' above.
    tf.flags.DEFINE_boolean('copy_to_local',
                            default=True,
                            help='if load imagenet weights or not')
    tf.flags.DEFINE_integer('num_gpus', default=1, help='use number of gpus')
    # NOTE(review): same help text as 'num_gpus' above.
    tf.flags.DEFINE_boolean('finetuning',
                            default=False,
                            help='use number of gpus')
    tf.flags.DEFINE_boolean('label_changed',
                            default=False,
                            help='whether number of labels is changed or not')
    # The learning rate and both thresholds are registered as string flags.
    tf.flags.DEFINE_string('learning_rate',
                           default='0.001',
                           help='learning rate to used for the optimizer')
    tf.flags.DEFINE_string('obj_threshold',
                           default=obj_threshold,
                           help='obj threshold')
    tf.flags.DEFINE_string('nms_threshold',
                           default=nms_threshold,
                           help='nms threshold')
    tf.flags.DEFINE_string('net_type',
                           default='resnet18',
                           help='resnet18 or resnet18_nas')
    # NOTE(review): same help text as 'net_type' above.
    tf.flags.DEFINE_string('nas_sequence',
                           default='64_1-2111-2-1112',
                           help='resnet18 or resnet18_nas')
    tf.flags.DEFINE_string('deploy_model_format',
                           default=None,
                           help='the format for the converted model')
    tf.flags.DEFINE_string('result_url',
                           default=None,
                           help='result url for training')

    # ``class_names`` and ``input_shape`` are forwarded here as the raw
    # comma separated strings, not the parsed lists used for the flags.
    incremental_instance = IncrementalLearning(estimator=Estimator)
    return incremental_instance.train(train_data=train_data,
                                      epochs=epochs,
                                      batch_size=batch_size,
                                      class_names=class_names,
                                      input_shape=input_shape,
                                      obj_threshold=obj_threshold,
                                      nms_threshold=nms_threshold)
示例#20
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from sedna.common.config import Context


# Database URL of the knowledge base; defaults to a local SQLite file when
# the ``KB_URL`` parameter is not set.
SQLALCHEMY_DATABASE_URL = Context.get_parameters(
    "KB_URL", "sqlite:///lifelong_kb.sqlite3")

# Engine shared by the session factory and the declarative base below.
# NOTE(review): ``check_same_thread`` is a sqlite3 connection option; it is
# passed unconditionally, so confirm it is accepted when ``KB_URL`` points at
# a different backend. ``echo=True`` presumably logs every SQL statement --
# confirm this is wanted outside of debugging.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL, encoding='utf-8',
    echo=True, connect_args={'check_same_thread': False}
)


# Session factory bound to the engine; autoflush and autocommit are disabled.
SessionLocal = sessionmaker(
    bind=engine, autoflush=False, autocommit=False, expire_on_commit=True)

# Declarative base class bound to the same engine.
Base = declarative_base(bind=engine, name='Base')
示例#21
0
 def __init__(self, trainset=None, testset=None) -> None:
     """Wrap the train/test splits as batched ``tf.data`` datasets.

     Args:
         trainset: Object exposing ``x`` and ``y`` for training.
         testset: Object exposing ``x`` and ``y`` for testing.

     The batch size is read from the ``batch_size`` context parameter
     (default 32) once per split.
     """
     def _as_batched_dataset(split):
         """Slice ``(split.x, split.y)`` into a batched dataset."""
         sliced = tf.data.Dataset.from_tensor_slices((split.x, split.y))
         batch_size = int(Context.get_parameters("batch_size", 32))
         return sliced.batch(batch_size)

     self.customized = True
     self.trainset = _as_batched_dataset(trainset)
     self.testset = _as_batched_dataset(testset)
示例#22
0
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import time
import warnings

import cv2
import numpy as np

from sedna.common.config import Context
from sedna.common.file_ops import FileOps
from sedna.core.incremental_learning import IncrementalLearning
from interface import Estimator

# Output locations, both defaulting to '/tmp' when the parameter is unset.
# NOTE(review): "HE" presumably stands for "hard example" -- confirm.
he_saved_url = Context.get_parameters("HE_SAVED_URL", '/tmp')
rsl_saved_url = Context.get_parameters("RESULT_SAVED_URL", '/tmp')
# Label names known to this module.
class_names = ['person', 'helmet', 'helmet_on', 'helmet_off']

# Runs at import time on both output locations.
# NOTE(review): with ``clean=False`` this presumably prepares the folders
# without deleting their contents -- confirm against FileOps.clean_folder.
FileOps.clean_folder([he_saved_url, rsl_saved_url], clean=False)


def draw_boxes(img, labels, scores, bboxes, class_names, colors):
    line_type = 2
    text_thickness = 1
    box_thickness = 1
    #  get color code
    colors = colors.split(",")
    colors_code = []
    for color in colors:
        if color == 'green':