Пример #1
0
def main(args):
    """Build the config, seed all RNGs, then run training (and optionally testing).

    Args:
        args: Parsed command-line namespace. ``args.config_name`` selects the
            registered config, and the whole namespace is handed to ``cfg.setup``.

    The config is always persisted via ``trainer.save_cfg()`` once the trainer
    exists, whether training finishes, is interrupted with Ctrl-C, or fails
    with another exception (which still propagates to the caller).
    """
    #### set up cfg ####
    # default cfg, then overlay the registered cfg
    cfg = get_cfg()
    cfg = build_config(cfg, args.config_name)
    cfg.setup(args)

    #### seed ####
    SEED = cfg.seed
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Deterministic cuDNN kernels and no autotuning, for reproducible runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    #### start searching ####
    trainer = build_trainer(cfg)
    try:
        trainer.train(cfg.trainer.validate_always)
        # A separate test pass only runs when validate_always is off.
        if not cfg.trainer.validate_always:
            trainer.test()
    except KeyboardInterrupt:
        # Ctrl-C stops training without a traceback; the config is still saved below.
        print('Capture KeyboardInterrupt event ...')
    finally:
        trainer.save_cfg()
Пример #2
0
def persis_image_sendkafka(img: np.ndarray, args: dict, cfg=None):
    """Write frame(s) as JPEG files and publish one Kafka message per frame.

    Each frame is written to ``<PICTURE_MOUNT_POINT>/<busitype>/<YYYYMMDD>/<uuid>.jpg``.
    Each message body is built as follows:
      * start from the JSON template read from ``test.json`` (relative to the
        current working directory);
      * overlay it with ``args``;
      * store the frame's URL (``<LOCAL_HOST_IP>/<busitype>/<YYYYMMDD>/<uuid>.jpg``)
        under that frame's key from the URL-key list.

    Args:
        img: A single image (``ndim == 3``) or a stack of images whose
            elements each have ``ndim == 3`` (presumably ``(N, H, W, C)`` --
            TODO confirm). ``None`` makes the call a no-op.
        args: Extra message fields. ``'busitype'`` selects the storage
            sub-directory; a random category is used when absent.
            ``'imgurl_key_list'`` gives, per frame, the JSON key that receives
            the frame's URL; ``'img_url'`` is used for every frame when absent.
        cfg: Config object; loaded from ``configs/DECODER.yaml`` when ``None``.

    Raises:
        AssertionError: if ``img`` or any frame has the wrong number of dimensions.
        IndexError: if ``'imgurl_key_list'`` is shorter than the number of frames.
    """
    if img is None:
        return
    assert img.ndim > 2

    # Promote a single HxWxC frame to a batch of one so both cases share the loop.
    imgs = np.expand_dims(img, axis=0) if img.ndim == 3 else img

    if cfg is None:
        cfg = get_cfg("configs/DECODER.yaml")
    server = cfg.KFK_SERVER_LIST
    username = cfg.KFK_CONSUMER_USER
    password = cfg.KFK_CONSUMER_PWD
    topic = cfg.KFK_TOPIC_CUSTOMERFLOW
    moutpoint = cfg.PICTURE_MOUNT_POINT  # e.g. "/dev/shm/PICTURE_MOUNT_POINT"
    localhostip = cfg.LOCAL_HOST_IP  # e.g. 10.10.117.131
    today = date.today().strftime('%Y%m%d')  # e.g. "20210303"

    # Fall back per key, so a missing 'imgurl_key_list' no longer discards a
    # caller-supplied 'busitype' (previously a bare except replaced both).
    if 'busitype' in args:
        busitype = args['busitype']
    else:
        busitype = random.choice(
            ["dustbin", "car", "bus", "track", "bicycle", "motocycle", "tricycle", "person", "face", "alien"])
    if 'imgurl_key_list' in args:
        urlkeylist = args['imgurl_key_list']
    else:
        urlkeylist = ['img_url'] * len(imgs)

    nj = NvJpeg()
    producer = ProducerWarp(server, username, password)

    # Message template; json.load accepts a binary file object.
    with open('test.json', "rb") as f:
        djson = json.load(f)
    # Caller-supplied fields override the template.
    djson.update(args)

    out_dir = os.path.join(moutpoint, busitype, today)
    url_dir = os.path.join(localhostip, busitype, today)
    os.makedirs(out_dir, exist_ok=True)
    for idx, frame in enumerate(imgs):
        assert frame.ndim == 3
        file_name = f"{uuid.uuid4().hex}.jpg"
        # Encode before opening the file so a failed encode leaves no empty .jpg behind.
        frame_jpg = nj.encode(frame)
        with open(os.path.join(out_dir, file_name), "wb") as fid:
            fid.write(frame_jpg)
        # NOTE(review): djson is reused across frames, so when urlkeylist holds
        # distinct keys, later messages also carry earlier frames' URLs -- confirm intended.
        djson[urlkeylist[idx]] = os.path.join(url_dir, file_name)

        producer.produce_business(json.dumps(djson).encode('utf-8'), topic)
        print('.', end='')
    sys.stdout.flush()
Пример #3
0
def setup_cfg(args):
    """Assemble the retraining config, create its log directory, and seed RNGs.

    The config is the registered config ``args.config_name``, overlaid with
    ``args.config_file`` and then with the ``args.opts`` key/value list.

    The log directory is chosen as follows:
      * when ``cfg.model.resume_path`` is set, use its directory;
      * otherwise, use the first ``<dirname(args.arc_path)>_retrain_<k>``
        (k = 0, 1, 2, ...) that does not exist yet.

    The log file is ``log_retrain.txt`` inside that directory. NumPy and
    PyTorch (CPU and all CUDA devices) are seeded with ``cfg.seed``, and cuDNN
    is put in deterministic, non-benchmark mode.

    Returns:
        The frozen config.
    """
    cfg = build_config(get_cfg(), args.config_name)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)

    resume_path = cfg.model.resume_path
    if resume_path:
        log_dir = os.path.dirname(resume_path)
    else:
        # Probe increasing suffixes until an unused directory name is found.
        arc_dir = os.path.dirname(args.arc_path)
        suffix = 0
        while True:
            log_dir = arc_dir + '_retrain_{}'.format(suffix)
            if not os.path.exists(log_dir):
                break
            suffix += 1
    cfg.logger.path = log_dir

    cfg.logger.log_file = os.path.join(cfg.logger.path, 'log_retrain.txt')
    os.makedirs(cfg.logger.path, exist_ok=True)
    cfg.freeze()

    seed = cfg.seed
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    return cfg
from fvcore.common.timer import Timer
import numpy as np
import tritonclient.http as httpclient
import tritonclient.utils.cuda_shared_memory as cudashm
from tritonclient import utils
from configs import get_cfg, get_logger
import uuid
import shared_numpy as snp

from worker_decoder import Worker_decoder
from worker_detector import Worker_detector

# Module-wide logger and decoder config. The config is loaded from
# configs/DECODER.yaml at import time; default values live in configs/__init__.py.
logger = get_logger()
cfg = get_cfg("configs/DECODER.yaml")

# Presumably the model's input width/height/channels -- TODO confirm.
model_w, model_h, model_c = 640, 384, 3
# Presumably the incoming frame width/height -- TODO confirm.
input_w, input_h = 1280, 720


def start_work(verbose=False):
    logger.debug("===============")
    for k, v in cfg.items():
        logger.debug(f"k is:{k}, v is :{v}")
    logger.debug("===============")