예제 #1
0
파일: driver.py 프로젝트: necla-ml/ML-WS
 def handler(sig, frame):
     logging.warning(f"Interrupted or stopped by {signal.Signals(sig).name}")
     nonlocal signaled
     signaled = True
     # XXX: avoid stale TCP connections
     broker.disconnect()
     sys.exit(signal.Signals(sig).value)
예제 #2
0
파일: av.py 프로젝트: necla-ml/ML
def codec(fmt):
    '''Map an informal codec description to its FFMPEG name and fourCC.

    Matching is a case-insensitive substring test; the first rule that
    matches wins, so the rule order below is significant.

    Args:
        fmt: informal codec format

    Returns:
        codec: registered codec name in FFMPEG, or None if unrecognized
        fourcc: four CC of the codec taken from CODECS, or None if unrecognized
    '''
    # (substrings to search for, FFMPEG codec name, key into CODECS)
    rules = (
        (('avc1', '264'), 'h264', 'h264'),
        (('hevc', '265'), 'hevc', 'h265'),
        (('mpeg4', 'mp42'), 'mpeg4', 'mpeg4'),
        (('jpg', 'jpeg'), 'mjpeg', 'mjpeg'),
        (('alaw',), 'pcm_alaw', 'pcm_alaw'),
        (('ulaw',), 'pcm_mulaw', 'pcm_mulaw'),
        (('yuyv',), 'yuyv', 'yuyv'),
    )

    haystack = fmt.lower()
    for needles, name, key in rules:
        if any(needle in haystack for needle in needles):
            # CODECS is only consulted once a rule has matched.
            return name, CODECS[key]

    logging.warning(f"Unknown codec format: {fmt}")
    return None, None
예제 #3
0
 def close(self):
     """Quit the loop, request the NULL pipeline state, then join.

     A state change that does not report SUCCESS is only logged; the join
     still happens and waits without a timeout.
     """
     logging.info(f"CLOSE {self.name}")
     self.loop.quit()

     outcome = self.pipeline.set_state(Gst.State.NULL)
     expected = Gst.StateChangeReturn.SUCCESS
     if outcome != expected:
         logging.warning('GST state change to NULL failed')

     self.join(timeout=None)
예제 #4
0
파일: decoder.py 프로젝트: necla-ml/ML
 def __next__(self):
     """Return the next parsed request, ending iteration when there is none.

     Raises:
         StopIteration: when `close_connection` is set and truthy, or when
             `parse()` returns None.
     """
     # The attribute may be absent; a missing flag counts as "still open".
     if getattr(self, 'close_connection', False):
         logging.warning(f"Connection closed")
         raise StopIteration

     request = self.parse()
     if request is not None:
         return request

     logging.warning(f"No requests")
     raise StopIteration
예제 #5
0
def upload_s3(path, bucket, key):
    '''Upload a local file to S3 while printing transfer progress.

    Args:
        path(str): path to the file to upload
        bucket(str): S3 bucket name
        key(str): key to upload to the bucket where the ending '/' matters

    Returns:
        bool: True if the upload succeeded. False if boto3/botocore cannot be
        imported, `path` is not an existing regular file, or the upload failed.
    '''
    try:
        import botocore  # noqa: F401 -- imported only to verify availability
        import boto3
        from boto3.exceptions import S3UploadFailedError
        from botocore.exceptions import ClientError
    except ImportError as e:
        logging.warning(
            f'botocore and boto3 are required to upload to S3: {e}')
        return False

    # XXX Amazon S3 supports buckets and objects, and there is no hierarchy.
    path = Path(path)
    if not path.is_file():
        logging.error(f"{path} not exist or not a file to upload")
        return False

    s3 = boto3.resource('s3').meta.client
    total = 0
    start = time()

    def callback(transferred):
        # Invoked by boto3 with the number of bytes sent since the last call.
        nonlocal total
        total += transferred
        # Guard against a zero interval on coarse clocks.
        elapse = max(time() - start, 1e-9)
        if total < 1024:
            print(
                f"\rUploaded {total:4d} bytes at {total / elapse:.2f} bytes/s",
                end='')
        elif total < 1024**2:
            KB = total / 1024
            print(f"\rUploaded {KB:4.2f}KB at {KB/elapse:4.2f} KB/s",
                  end='')
        else:
            MB = total / 1024**2
            print(f"\rUploaded {MB:8.2f}MB at {MB/elapse:6.2f} MB/s",
                  end='')
        sys.stdout.flush()

    logging.info(f"Uploading {path} to s3://{bucket}/{key}")
    try:
        s3.upload_file(str(path), bucket, key, Callback=callback)
    except (ClientError, S3UploadFailedError) as e:
        # upload_file reports transfer failures as S3UploadFailedError, so
        # catching ClientError alone would let them escape.
        print()
        logging.error(
            f"Failed to upload {path} to s3://{bucket}/{key}: {e}")
        return False

    print()
    logging.info(f"Succeeded to upload {path} to s3://{bucket}/{key}")
    return True
예제 #6
0
    def __init__(
        self,
        split,
        tokenization,
        path="data/flickr",
        max_tokens=80,
        max_entities=16,
        max_rois=100,
        transform=None,
        target_transform=None,
    ):
        """Load ROI features and annotations for one dataset split.

        If `<split>.hdf5` or `<split>_imgid2idx.pkl` is missing under `path`,
        features are first extracted from the TSV files listed in `self.res`.

        Args:
            split: split name used to build the feature/index file names.
            tokenization: tokenizer name; 'bert' or 'wordpiece' triggers
                `ml.nlp.bert.setup()`.
            path: dataset root directory.
            max_tokens: stored on the instance as `max_tokens`.
            max_entities: stored on the instance as `max_entities`.
            max_rois: stored on the instance as `max_rois`.
            transform: accepted but not used in this constructor body.
            target_transform: accepted but not used in this constructor body.
        """
        # XXX Use Entities from BAN instead
        import h5py
        path = Path(path)

        # ROI features
        h5 = path / f"{split}.hdf5"
        imgid2idx = path / f"{split}_imgid2idx.pkl"
        if not h5.exists() or not imgid2idx.exists():
            logging.warning(
                f"{h5} or {imgid2idx} not exist, extracting features on the fly..."
            )
            prefix = path / self.res["features"]["cfg"]
            tsvs = [prefix / tsv for tsv in self.res["features"][split]]
            logger.info(f"Extracting ROI features from {prefix}")
            extract(split, tsvs, path)

        logger.info(f"Loading image/RoI features from {h5}")
        # NOTE: pickle executes arbitrary code on load; only use index files
        # produced by a trusted extraction run.
        with open(imgid2idx, "rb") as index_file:
            self.imgid2idx = pickle.load(index_file)
        # Use a separate name for the open file so the path in `h5` survives.
        with h5py.File(h5, "r") as store:
            self.offsets = th.from_numpy(np.array(store.get("pos_boxes")))
            self.features = th.from_numpy(np.array(store.get("image_features")))
            self.spatials = th.from_numpy(
                np.array(store.get("spatial_features")))
            self.rois = th.from_numpy(np.array(store.get("image_bb")))

        # Entities and ground truth bboxes
        #   pos_box start offset => annotation
        self.max_tokens = max_tokens
        self.max_entities = max_entities
        self.max_rois = max_rois
        self.tokenization = tokenization
        self.annotations = _load_flickr30k(
            split,
            path,
            self.imgid2idx,
            self.offsets,
            self.rois  #, self.tokenize, self.tensorize
        )

        if tokenization in ['bert', 'wordpiece']:
            from ml.nlp import bert
            bert.setup()
예제 #7
0
파일: youtube.py 프로젝트: necla-ml/ML-WS
def yt_hls_url(url, *args, file_name='video.h264', **kwargs):
    """
    Get hls url if live stream else download video and transcode
    params:
        url - youtube url to download video from
        file_name - output file for the transcoded (non-live) video
        start - start video from this timestamp(00:00:15)
        end - end video after this timestamp(00:00:10); passed to ffmpeg
              as `-t` and defaulting to 00:05:00
    Returns:
        HLS url (str) for a live stream, absolute path (str) of the
        transcoded file otherwise, or None if resolving or transcoding failed
    """
    res = None
    start = kwargs.pop('start', None)
    # NOTE: enforce 5 min limit on non-live youtube videos
    end = kwargs.pop('end', None) or '00:05:00'
    try:
        # video is live --> get hls url and stream
        res = subprocess.run(['youtube-dl', '-f', '95', '-g', url],
                             stdout=subprocess.PIPE) \
            .stdout.decode('utf-8') \
            .strip()
        if not res:
            logging.warning(
                f"video is not live --> transcode to h264 and stream")
            media_url = subprocess.run(
                ['youtube-dl', '-f', 'best', '-g', url],
                stdout=subprocess.PIPE).stdout.decode('utf-8').strip()
            if not media_url:
                logging.error(f"youtube-dl returned no media url for {url}")
                return None

            # Build argv directly so that urls or file names containing
            # spaces/quotes are passed through intact.
            cmd = ['ffmpeg']
            if start:
                cmd += ['-ss', str(start)]
            cmd += ['-i', media_url, '-t', str(end)]
            cmd += [
                '-an',
                '-s', '1280x720',
                '-g', '15',
                '-r', '15',
                '-b', '2M',
                '-vcodec', 'h264',
                '-bf', '0',
                '-bsf', 'h264_mp4toannexb',
                '-y', file_name,
            ]

            status = subprocess.call(cmd)
            if status != 0:
                # Do not hand back a path to a missing or partial file.
                logging.error(f"ffmpeg exited with status {status}")
                res = None
            else:
                res = os.path.abspath(file_name)
    except Exception as e:
        # TODO: handle transcoding error with proper errno key
        #sys.exit(errno.)
        logging.error(f"Failed to resolve or transcode {url}: {e}")

    return res
예제 #8
0
def download_gdrive(id='1mM8aZJlWTxOg7BZJvNUMrTnA2AbeCVzS',
                    path='/tmp/yolov5x.pt',
                    force=False):
    """Download a file from Google Drive with curl.

    Based on https://gist.github.com/tanaikech/f0f2d122e05bf5f971611258c22c110f
    A first request stores a `cookie` file in the current directory; if one
    is written, its confirmation token is used for the actual download.

    Args:
        id: Google Drive file id. It is interpolated into shell commands
            as is, so it must come from a trusted source.
        path: destination file path.
        force: remove an existing file at `path` and download again.

    Returns:
        int: 0 if the file already exists (and `force` is False) or the
        download succeeded; otherwise the non-zero status from `os.system`.
    """
    import shlex
    import time
    t = time.time()

    if os.path.exists(path):
        if not force:
            logging.warning(
                f"Download exists: {path}, specify force=True to remove if necessary"
            )
            return 0
        os.remove(path)
        logging.warning(f"Removed existing download: {path}")

    logging.info(
        f'Downloading https://drive.google.com/uc?export=download&id={id} to {path}...'
    )
    if os.path.exists('cookie'):
        os.remove('cookie')

    # Quote the destination so paths with spaces or shell metacharacters work.
    target = shlex.quote(str(path))

    # Attempt file download
    os.system(
        f"curl -c ./cookie -s -L \'https://drive.google.com/uc?export=download&id={id}\' > /dev/null"
    )
    if os.path.exists('cookie'):  # large file
        s = f"curl -Lb ./cookie \"https://drive.google.com/uc?export=download&confirm=`awk '/download/ {{print $NF}}' ./cookie`&id={id}\" -o {target}"
    else:  # small file
        s = f"curl -s -L -o {target} 'https://drive.google.com/uc?export=download&id={id}'"
    r = os.system(s)  # execute, capture return values
    if os.path.exists('cookie'):
        os.remove('cookie')

    # Error check
    if r != 0:
        if os.path.exists(path):
            os.remove(path)  # remove partial
        logging.error(f'Failed to download to {path}')
        return r

    # NOTE: archives are left as downloaded; nothing is unzipped here.
    logging.info(f'Done in {time.time() - t:.1f}s')
    return r
예제 #9
0
파일: utils.py 프로젝트: necla-ml/ML
def get_calibration_files(calibration_data,
                          max_calibration_size=None,
                          allowed_extensions=(".jpeg", ".jpg", ".png")):
    """Returns a sorted list of all filenames ending with `allowed_extensions` found in the `calibration_data` directory.

    Parameters
    ----------
    calibration_data: str
        Path to directory containing desired files. It is searched recursively.
    max_calibration_size: int
        Max number of files to use for calibration. If calibration_data contains more than this number,
        a random sample of size max_calibration_size will be returned instead. If None, all samples will be used.
    allowed_extensions: str or iterable of str
        File name suffixes to accept; compared case-insensitively.

    Returns
    -------
    calibration_files: List[str]
         List of filenames contained in the `calibration_data` directory ending with `allowed_extensions`.

    Raises
    ------
    FileNotFoundError
        If no matching file is found.
    """
    # str.endswith only accepts a str or a tuple, so normalize lists/sets.
    if isinstance(allowed_extensions, str):
        allowed_extensions = (allowed_extensions, )
    extensions = tuple(ext.lower() for ext in allowed_extensions)

    logging.info(
        "Collecting calibration files from: {:}".format(calibration_data))
    # Sort because glob order follows the filesystem; without it the seeded
    # sample below would differ from machine to machine.
    calibration_files = sorted(
        path for path in glob.iglob(os.path.join(calibration_data, "**"),
                                    recursive=True)
        if os.path.isfile(path) and path.lower().endswith(extensions))
    logging.info("Number of Calibration Files found: {:}".format(
        len(calibration_files)))

    if not calibration_files:
        raise FileNotFoundError(
            "ERROR: Calibration data path [{:}] contains no files!".format(
                calibration_data))

    if max_calibration_size and len(calibration_files) > max_calibration_size:
        logging.warning(
            "Capping number of calibration images to max_calibration_size: {:}"
            .format(max_calibration_size))
        # A private generator keeps the sample reproducible without
        # reseeding the process-wide random state.
        calibration_files = random.Random(42).sample(calibration_files,
                                                     max_calibration_size)

    return calibration_files
예제 #10
0
    def send_records(self, records, attempt=0):
        """Push records to the Kinesis stream, retrying the rejected ones.

        Each retry waits `2**attempt * 0.1` seconds first. Once `attempt`
        exceeds `self.max_retries`, the records still pending are appended
        to `failed_records.dlq` instead of being sent.

        Parameters
        ----------
        records : array
            Array of formatted records to send.
        attempt: int
            Number of times the records have been sent without success.

        Raises
        ------
        Exception
            Whatever `put_records` raises, after it has been logged.
        """
        pending = records
        while True:
            # Retry budget exhausted: spill what is left to the dead-letter file.
            if attempt > self.max_retries:
                logging.warning(
                    f'[{self._name}] Writing {len(pending)} records to file')
                with open('failed_records.dlq', 'ab') as dlq:
                    for entry in pending:
                        dlq.write(entry.get('Data'))
                return

            # Back off before every retry, never before the first attempt.
            if attempt:
                time.sleep(2**attempt * .1)

            try:
                response = self.kinesis_client.put_records(
                    StreamName=self.stream_name, Records=pending)
            except Exception as e:
                logging.error(f'[{self._name}]: {e}')
                raise e

            if not response['FailedRecordCount']:
                return

            logging.warning(f'[{self._name}] Retrying failed records')
            # Response entries are matched to the submitted records by index.
            pending = [
                pending[position]
                for position, outcome in enumerate(response['Records'])
                if outcome.get('ErrorCode')
            ]
            attempt += 1
예제 #11
0
def test_single_session(ip, port, user, passwd, areas, FPS):
    """Open one 'Original' session on the second area and read 300 packets.

    Args:
        ip, port, user, passwd: connection settings for the source.
        areas: sequence of area names; index 1 is used.
        FPS: mapping from area name to the frame rate to request.
    """
    area = areas[1]
    fps = FPS[area]
    decoding = False
    source = AVSource.create(url(ip, port), user=user, passwd=passwd)
    sessions = source.open(area,
                           'Original',
                           fps=fps,
                           decoding=decoding,
                           exact=True,
                           with_audio=True)
    assert len(sessions) == 1

    session = sessions[0]
    video = session['video']
    logging.info(f"Session: \n{session}")

    total = 300
    X = sys.x_available()  # NOTE(review): result is not used below
    m = None  # stays None if every read is skipped
    for i in range(total):
        res = source.read(session, media=None)
        if res is None:
            logging.warning(f"Skipped invalid frame")
            continue

        m, media, frame = res
        if decoding:
            # Fixed key typo: was media['timbe_base'], a KeyError as soon as
            # decoding is switched on.
            duration = float(
                media['duration'] * media['time_base']
            ) if m == 'video' else frame.size / 2 / media['sample_rate']
            logging.info(
                f"{m}[{media['count']}]: {frame.shape} of {frame.dtype} with duraton {duration:.3f}s at {media['time']:.3f}s, now={time():.3f}s"
            )
        else:
            duration = float(
                media['duration'] * media['time_base']
            ) if m == 'video' else frame.size / media['sample_rate']
            logging.info(
                f"{m}[{media['count']}]: {media['keyframe'] and 'key ' or ''}({frame.size} bytes) with duration {duration:.3f}s at {media['time']:.3f}s, now={time():.3f}s"
            )
    logging.info(f"{m}/{video['format']} stream FPS={video['fps_rt']:.2f}")
    source.close(session)
    assert not session
예제 #12
0
파일: avsource.py 프로젝트: necla-ml/ML-WS
    def get(self, session, key, media='video'):
        """Read a capture property of the session's video stream.

        Args:
            session: session mapping; the entry under 'video' is consulted.
            key: one of the `av.VIDEO_IO_FLAGS` properties handled below.
            media: media type; only 'video' is handled.

        Returns:
            The property value, or None (after a warning) when the media type
            or key is not handled. The fourCC is also None when the session
            has no codec or the looked-up fourCC is falsy.
        """
        if media == 'video' and media in session:
            video = session[media]
            flags = av.VIDEO_IO_FLAGS
            if key == flags.CAP_PROP_FOURCC:
                codec = video['codec']
                if not codec:
                    return None
                return av.avcodec(codec.name)[1] or None
            if key == flags.CAP_PROP_FRAME_WIDTH:
                return video['width']
            if key == flags.CAP_PROP_FRAME_HEIGHT:
                return video['height']
            if key == flags.CAP_PROP_FPS:
                return video['fps']
            if key == flags.CAP_PROP_POS_MSEC:
                # 'time' is scaled by 1000 for the *_MSEC property.
                return video['time'] * 1000
            if key == flags.CAP_PROP_BUFFERSIZE:
                return 0

        logging.warning(f"Unknown key to get: {key} from {media}")
        return None
예제 #13
0
def parse(url):
    """Split a download URL into the fields needed to fetch it.

    Two forms are recognized:
      * `s3://<bucket>/<key>`
      * github.com release assets, e.g.
        https://github.com/ultralytics/yolov5/releases/download/v3.0/yolov5x.pt

    Args:
        url(str): URL to parse.

    Returns:
        dict: `scheme`, `bucket`, `key`, `name` for S3 URLs, or
        `owner`, `project`, `tag`, `name` for github release URLs.
        None (after a warning) for anything else.
    """
    spec = urlparse(url)
    if spec.scheme == 's3':
        key = spec.path[1:]
        return dict(
            scheme='s3://',
            bucket=spec.netloc,
            key=key,
            name=os.path.basename(key),
        )

    if spec.netloc == 'github.com':
        path = Path(spec.path)
        ancestors = path.parents
        # /<owner>/<project>/releases/download/<tag>/<name> has six
        # ancestors; shorter paths are not release assets and used to raise
        # IndexError here.
        if len(ancestors) >= 6:
            return dict(owner=ancestors[4].name,
                        project=ancestors[3].name,
                        tag=path.parent.name,
                        name=path.name)

    # Fixed: the message lacked the f-prefix and logged a literal "{url}".
    logging.warning(f'Unknown url spec={url}')
    return None
예제 #14
0
파일: calibrator.py 프로젝트: necla-ml/ML
    def __init__(self,
                 batch_size=32,
                 inputs=[],
                 cache=None,
                 calibration_files=[],
                 max_calib_data=512,
                 preprocess_func=None,
                 algorithm=trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2,
                 device=torch.cuda.default_stream().device):
        """Set up buffers and the padded file list used for calibration.

        Args:
            batch_size: number of samples per calibration batch.
            inputs: input shapes (without the batch dimension); one float32
                buffer of shape `(batch_size, *shape)` is allocated per entry.
            cache: calibration cache file name; a unique name is generated
                when omitted.
            calibration_files: files to calibrate on. The list is copied and
                never modified.
            max_calib_data: stored on the instance as `max_calib_data`.
            preprocess_func: callable stored as `preprocess_func`; defaults
                to `preprocessor()`.
            algorithm: calibration algorithm, stored as `algorithm`.
            device: device the input buffers are allocated on.

        Raises:
            ValueError: if `inputs` is empty.
        """
        # NOTE(review): the `algorithm` and `device` defaults are evaluated
        # once, when the class body is executed.
        super().__init__()

        self.inputs = inputs
        if not inputs:
            raise ValueError(
                'Input shapes is required to generate calibration dataset')

        # unique cache file name in case multiple engines are built in parallel
        self.cache = cache or f'{uuid4().hex}.cache'
        self.batch_size = batch_size
        self.max_calib_data = max_calib_data
        self.algorithm = algorithm
        # Copy so the padding below cannot mutate the caller's list or the
        # shared default argument.
        self.files = list(calibration_files)
        self.buffers = [
            torch.empty((batch_size, *input_shape),
                        dtype=torch.float32,
                        device=device) for input_shape in inputs
        ]

        # Pad the list so it is a multiple of batch_size
        remainder = len(self.files) % self.batch_size
        if self.files and remainder:
            logging.info(
                "Padding number of calibration files to be a multiple of batch_size {:}"
                .format(self.batch_size))
            count = len(self.files)
            missing = self.batch_size - remainder
            # Reuse existing files, wrapping around so that lists shorter
            # than one batch are still filled completely.
            self.files += [
                self.files[(remainder + i) % count] for i in range(missing)
            ]

        if not preprocess_func:
            logging.warning(
                'default preprocessing applied to convert input to RGB tensor followed by ImageNet resize, crop and normaliztion.'
            )
        self.preprocess_func = preprocess_func or preprocessor()
        self.batches = self.load_batches()
예제 #15
0
 def update(self, instance):
     """Update track statistics and history given a new observation or last tracked.

     Args:
         instance(xyxysc or xyxy): up to date detection or predicted position.
             More than four entries is treated as a detection carrying a
             score and class at positions 4 and 5; exactly four entries is
             treated as a predicted box.

     Side effects:
         Updates `predicted`, `age`, `score`, `velocity` (and `origin` on the
         first update) and appends the new entry to `history`, dropping the
         oldest entry once `history` exceeds `self.length`.
     """
     # NOTE(review): an instance with fewer than four entries matches neither
     # branch below, leaving `xyxysc` unbound -- confirm callers never pass one.
     instance = instance.cpu()
     info = len(instance)
     score = self.score
     if info > 4:
         # detection
         self.predicted = False
         score, cls = instance[4:6]
         score, cls = score.item(), cls.item()
         xyxysc = instance[:6]
         if cls != self.cls:
             # Class changed: keep the track's own score and class by
             # overwriting the last two entries of a copy of the detection.
             logging.warning(f"Inconsistent detection class from {self.cls} to {cls}")
             xyxysc = instance.clone()
             xyxysc[-2] = score = self.score
             xyxysc[-1] = self.cls
     elif info == 4:
         # prediction
         self.predicted = True
         # logging.warning(f"track[{self.tid}] predicted={instance.round().int().tolist()}({self.score:.2f})")
         # Extend the box with the track's current score and class.
         xyxysc = th.cat((instance, th.Tensor([self.score, self.cls])))

     # Running mean of the score over all updates, including this one.
     age = self.age
     self.age += 1
     self.score = (age * self.score + score) / self.age
     if age == 0:
         # First update: remember the box center and start at zero velocity.
         self.origin = (xyxysc[2:4] + xyxysc[:2]) / 2
         self.velocity = th.zeros_like(self.origin)
     else:
         # `self.last` is defined outside this method; presumably the most
         # recent history entry -- TODO confirm.
         last = self.last
         prev = (last[2:4] + last[:2]) / 2
         center = (xyxysc[2:4] + xyxysc[:2]) / 2
         velocity = center - prev
         # FIXME moving average could be too slow
         # self.velocity = ((age - 1) * self.velocity + velocity) / age
         # Average the previous velocity with the latest center displacement.
         self.velocity = (self.velocity + velocity) / 2
         # print(f"snapshot[{self.tid}]:", prev.tolist(), center.tolist(), velocity.tolist(), self.velocity.tolist())
     self.history.append(xyxysc)
     if len(self.history) > self.length:
         self.history.pop(0)
예제 #16
0
파일: utils.py 프로젝트: necla-ml/ML-WS
def decrypt(private_key, value):
    """
    Decrypt value with the private key
    Params:
        private_key: Fernet key (str or bytes)
        value: value to be decrypted by the private key(str or byte)
    Returns:
        the decrypted text parsed as JSON when it is valid JSON,
        the decrypted text itself (str) when it is not,
        or None when decryption fails
    """
    cipher_key = private_key if isinstance(private_key, bytes) else private_key.encode()
    # init fernet
    cipher = Fernet(cipher_key)
    token = value if isinstance(value, bytes) else str(value).encode()

    # Decryption and JSON parsing are handled separately so that a failed
    # decryption is no longer reported as "not in json".
    try:
        plaintext = cipher.decrypt(token).decode("utf-8")
    except Exception as e:
        logging.warning(f"Failed to decrypt value: {e!r}")
        return None

    try:
        return json.loads(plaintext)
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass.
        logging.warning(f"Not in json, assume string value")
        return plaintext
예제 #17
0
파일: awscam.py 프로젝트: necla-ml/ML-WS
    def open(self):
        """Open the encoder output for this channel and reset read state.

        Returns:
            bool: True after the stream has been opened, False (with a
            warning) if a stream is already open.
        """
        if self.stream is None:
            self.encoder = av.open(ENCODERS[self.ch])
            self.video = self.encoder.streams[0]
            self.codec = self.video.codec_context
            self.stream = self.encoder.demux()

            ctx = self.codec
            self.duration = float(ctx.time_base * ctx.ticks_per_frame)
            self.fps = 1 / self.duration  # nominal FPS
            self.rate = float(ctx.rate)  # average FPS

            # Nothing has been read yet.
            self.frames = 0
            self.started = False
            self.start = None
            self.time = None
            return True

        logging.warning("Already opened")
        return False
예제 #18
0
def test_streaming_all(credentials, duration=5):
    """Connect to the NVR and stream from every camera except the PTZ one.

    Args:
        credentials: mapping with 'ip', 'port', 'username' and 'passwd'.
        duration: value forwarded to `startStreaming` for each camera.
    """
    ip, port, user, passwd = (
        credentials[field] for field in ('ip', 'port', 'username', 'passwd'))
    nvr = NVR.create(ip, port, user=user, passwd=passwd)
    assert isinstance(nvr, Titan8040R)
    nvr.connect()
    logging.info(
        f"##### Probing all streaming profiles for {duration} frames #####")

    for cam in nvr:
        area = cam['area']
        if area == 'First Floor PTZ':
            logging.warning(f"Skipping 'First Floor PTZ'")
            continue

        # An exact query is expected to match exactly one three-item config.
        matches = nvr.query(area=area, exact=True)
        assert len(matches) == 1
        cfg = matches[0]
        assert len(cfg) == 3
        startStreaming(nvr, cfg, duration)
        print()
예제 #19
0
파일: awscam.py 프로젝트: necla-ml/ML-WS
    def read(self):
        """Return the next frame from the open stream.

        Until the first frame has been accepted, frames are skipped while
        they arrive sooner than 9/10 of `self.duration` after the previous
        one or are not key frames. Afterwards each call takes one frame and
        updates `self.time` and `self.duration`.

        Returns:
            The frame read, or None (with a warning) if no stream is open.
        """
        if self.stream is None:
            logging.warning('No stream open')
            return None

        if self.started:
            # Compensate unexpected encoder latency
            frame = next(self.stream)
            arrived = time()
            self.time += self.duration
            nominal = float(frame.time_base * frame.duration)
            # Next expected duration: nominal plus half of the observed lag.
            self.duration = nominal + (arrived - self.time) / 2
        else:
            # Read until the first fresh key frame w.r.t. the specified FPS
            last_seen = time()
            for index, frame in enumerate(self.stream):
                arrived = time()
                gap = arrived - last_seen
                if not (gap < 9 * self.duration / 10) and frame.is_keyframe:
                    self.started = True
                    self.time = self.start = arrived
                    logging.info(f"frame[{index}] First fresh key frame")
                    break
                logging.warning(
                    f"frame[{index}] Skipped a buffered stale {frame.is_keyframe and 'key ' or '    '}frame for short duration of {gap:.3f}s < {self.duration:.3f}s"
                )
                last_seen = arrived

        self.frames += 1
        return frame
예제 #20
0
파일: avsource.py 프로젝트: necla-ml/ML-WS
    def set(self, session, key, value, media='video'):
        """Change a capture property on the session's video stream.

        Args:
            session: session mapping; the entry under 'video' is updated.
            key: `av.VIDEO_IO_FLAGS` property to change.
            value: new value; for CAP_PROP_FOURCC an informal codec name
                that is resolved through `av.codec`.
            media: media type; only 'video' is handled.

        Returns:
            bool: True if the stream accepted the change, False otherwise
            (a warning is logged in that case).
        """
        if media == 'video' and media in session:
            video = session[media]
            stream = video['stream']
            if not hasattr(stream, 'set'):
                logging.warning(f"Source stream property cannot be changed")
                return False

            flags = av.VIDEO_IO_FLAGS
            if key == flags.CAP_PROP_FOURCC:
                fmt, fourcc = av.codec(value)
                if stream.set(flags.CAP_PROP_FOURCC, fourcc):
                    video['format'] = fmt
                    return True
                else:
                    logging.warning(
                        f"Failed to set video source CAP_PROP_FOURCC to {fmt}({value})"
                    )
                    return False
            elif stream.set(key, value):
                # Read the value back: the stream reports what it applied.
                res = stream.get(key)
                if key == flags.CAP_PROP_FPS:
                    video['fps'] = int(res)
                    logging.warning(
                        f"Set video source CAP_PROP_FPS to {value}({int(res)})"
                    )
                elif key == flags.CAP_PROP_FRAME_WIDTH:
                    video['width'] = int(res)
                    # Fixed: this lookup used av.CAP_PROP_FRAME_HEIGHT while
                    # every other reference goes through av.VIDEO_IO_FLAGS.
                    video['height'] = int(
                        stream.get(flags.CAP_PROP_FRAME_HEIGHT))
                    logging.info(
                        f"Set video source CAP_PROP_FRAME_WIDTH to {value}({int(res)})"
                    )
                elif key == flags.CAP_PROP_FRAME_HEIGHT:
                    video['height'] = int(res)
                    video['width'] = int(
                        stream.get(flags.CAP_PROP_FRAME_WIDTH))
                    logging.info(
                        f"Set video source CAP_PROP_FRAME_HEIGHT to {value}({int(res)})"
                    )
                return True

        logging.warning(f"Unsupported {media} property to set")
        return False
예제 #21
0
파일: avsource.py 프로젝트: necla-ml/ML-WS
def openAV(src, decoding=False, with_audio=False, **kwargs):
    """Open an audio/video source with `av.open` and describe it as a session.

    Args:
        src: source to open. A value whose string form starts with 'rtsp'
            takes the RTSP path; an int or a '/dev/video*' string is treated
            as a webcam; anything else is handed to `av.open` unchanged.
        decoding: stored in the session under 'decoding'; forced to True for
            webcams.
        with_audio: not read anywhere in this function body.
        **kwargs: optional 'fps' (default 10), 'rtsp_transport', 'stimeout',
            'resolution', 'adaptive' and 'workaround'.

    Returns:
        dict: session with keys 'src', 'streams', 'format', 'decoding',
        'start' and 'rt', plus 'video' when the container has a video stream.

    Raises:
        Exception: any error raised while opening the source is logged and
        re-raised, including the AssertionError when no RTSP transport works
        and the ValueError for webcams on an unsupported platform.
    """
    try:
        format = None
        options = None
        fps = float(kwargs.get('fps', 10))
        if str(src).startswith('rtsp'):
            # ffmpeg RTSP options:
            # rtsp_transport: tcp, http, udp_multicast, udp
            # rtsp_flags: prefer_tcp, filter_src, listen, none
            # allowed_media_types: video, audio, data
            # stimeout: socket TCP I/O timeout in us
            # RTSP/HTTP required for VPN but not necessarily supported
            options = dict(
                rtsp_transport=kwargs.get('rtsp_transport', 'tcp'),
                rtsp_flags='prefer_tcp',
                stimeout=kwargs.get('stimeout',
                                    '5000000'))  # in case of network down

            # NOTE: retry with different rtsp transport types if unspecified
            # NOTE(review): `options` always carries 'rtsp_transport' at this
            # point (default 'tcp'), so this condition looks always true
            # unless a caller passes a falsy transport; a caller-supplied
            # transport is then overwritten by the loop below -- confirm
            # whether that is intended.
            if options and options.get('rtsp_transport', None):
                source = None
                for transport in ['tcp', 'http']:
                    try:
                        options['rtsp_transport'] = transport
                        source = av.open(src,
                                         format=format,
                                         options=options,
                                         timeout=(15, 5))
                    except Exception as e:
                        logging.warning(
                            f'Failed with rtsp_transport={transport}: {e}')
                    else:
                        options['rtsp_transport'] = transport
                        break
                # Raised inside the outer try, so it is logged and re-raised
                # by the handler at the bottom.
                assert source is not None, f"Failed to open RTSP source over TCP/HTTP"
            else:
                source = av.open(src,
                                 format=format,
                                 options=options,
                                 timeout=(15, 5))
        else:
            if isinstance(src, int) or (isinstance(src, str)
                                        and src.startswith('/dev/video')):
                # XXX webcam: high FPS with MJPG
                import platform
                system = platform.system()
                resolution = av.resolution_str(
                    *kwargs.get('resolution', ['720p']))
                options = {
                    'framerate': str(fps),
                    'video_size': resolution,
                    'input_format': 'mjpeg'
                }
                decoding = True
                if system == 'Darwin':
                    src = str(src)
                    format = 'avfoundation'
                elif system == 'Linux':
                    src = f"/dev/video{src}" if isinstance(src, int) else src
                else:
                    raise ValueError(f"Webcam unsupported on {system}")
            # Non-RTSP sources are opened without an explicit timeout.
            source = av.open(src, format=format, options=options)

        # timeout: maximum timeout (in secs) to wait for incoming connections and socket reading
        # XXX HLS connection potential time out for taking more than 5s
        # NOTE(review): this message always reports timeout=(15, 5), although
        # the non-RTSP call above passes no timeout.
        logging.info(
            f"av.open({src}, format={format}, options={options}, timeout=(15, 5))"
        )
    except Exception as e:
        logging.error(e)
        raise e
    else:
        '''
        H.264 NALU formats:
        Annex b.: 
            RTSP/RTP: rtsp, 'RTSP input', set()
            bitstream: h264, 'raw H.264 video', {'h26l', 'h264', '264', 'avc'}
            webcam: /dev/videoX, ...
            DeepLens: /opt/.../...out
            NUUO/NVR: N/A
        AVCC:
            avi: avi, 'AVI (Audio Video Interleaved)', {'avi'}
            mp4: 'mov,mp4,m4a,3gp,3g2,mj2', 'QuickTime / MOV', {'m4a', 'mov', 'mp4', 'mj2', '3gp', '3g2'}
            webm/mkv: 'matroska,webm', 'Matroska / WebM', {'mks', 'mka', 'mkv', 'mk3d'}
            DASH/KVS(StreamBody): file-like obj
        '''
        now = time.time()
        # NOTE(review): this division fails if `source.start_time` is None --
        # confirm the container always reports a start time.
        start_time = source.start_time / 1e6  # us
        # A start more than 30 days away from the wall clock is treated as a
        # relative (logical) timestamp rather than an absolute one.
        relative = abs(start_time -
                       now) > 60 * 60 * 24 * 30  # Too small to be absolute
        rt = not (isinstance(src, str) and os.path.isfile(src)
                  )  # regular file or not
        if rt:
            logging.info(f"Assume real-time source: {src}")
        else:
            logging.info(f"Simulating local source as real-time: {src}")

        # XXX start_time may be negative (webcam), zero if unavailable, or a small logical timestamp
        session = dict(
            src=src,
            streams=source,
            format=source.format.name,
            decoding=decoding,
            # Wall-clock time for relative sources, the source's own start
            # time otherwise.
            start=relative and now or start_time,
            rt=rt,
        )
        session_start_local = strftime('%X', localtime(session['start']))
        source_start_local = strftime('%X', localtime(start_time))
        logging.info(
            f"Session start: {session['start']:.3f}s({session_start_local}), source start: {start_time:.3f}s({source_start_local})"
        )
        if source.streams.video:
            # FIXME RTSP FPS might be unavailable or incorrectly set
            video0 = source.streams.video[0]
            codec = video0.codec_context
            FPS = 1 / (codec.time_base * codec.ticks_per_frame)
            # Use codec.framerate only when the derived FPS exceeds 60 and a
            # framerate is set; otherwise keep the requested fps.
            fps = FPS > 60 and (codec.framerate
                                and float(codec.framerate)) or fps
            session['video'] = dict(
                stream=source.demux(video=0),
                start=video0.start_time,  # same as 1st frame in pts
                codec=codec,
                format=video0.name,
                width=video0.width,
                height=video0.height,
                fps=fps,
                count=0,
                time=0,  # pts in secs
                duration=None,  # frame duration in secs
                drifting=False,
                adaptive=kwargs.get('adaptive', True),
                workaround=kwargs.get('workaround', True),
                thresholds=dict(drifting=10, ),
                prev=None,
            )
            logging.info(
                f"codec.framerate={codec.framerate}, codec.time_base={codec.time_base}, codec.ticks_per_frame={codec.ticks_per_frame}, fps={session['video']['fps']}, FPS={FPS}"
            )
        if source.streams.audio:
            # Audio is detected but no 'audio' entry is added to the session.
            audio0 = source.streams.audio[0]
            codec = audio0.codec_context
            logging.warning(f"No audio streaming supported yet")
            '''
            if codec.name == 'aac':
                logging.warning(f"AAC is not supported yet")
            else:
                session['audio'] = dict(
                    stream=decoding and source.decode(audio=0) or source.demux(audio=0),
                    start=audio0.start_time,        # same as 1st frame in pts
                    format=audio0.name,
                    codec=codec,
                    sample_rate=codec.sample_rate,
                    channels=len(codec.layout.channels),
                    count=0,
                    time=0,
                )
            '''
        return session
예제 #22
0
def download_s3(bucket, key, path=None, progress=True):
    '''Download an object from S3 to a local file.

    The object is written to a temporary file next to the destination and only
    moved into place once the transfer succeeded, so a failed download does not
    leave a partial file at the target path.

    Args:
        bucket(str): S3 bucket name
        key(str): path to a file to download in the bucket
    Kwargs:
        path(str): directory to save the downloaded file named by the key or the target path to save
        progress(bool): print the running amount and rate of downloaded data to stdout

    Returns:
        bool: True on success; False if botocore/boto3 cannot be imported or
        the S3 client raised a ClientError.
    '''
    try:
        import botocore
        import boto3
        from botocore.exceptions import ClientError
    except ImportError as e:
        logging.warning(
            f'botocore and boto3 are required to download from S3: {e}')
        return False

    import os
    import tempfile

    s3 = boto3.client(
        's3', config=botocore.client.Config(max_pool_connections=50))
    path = Path(path or '.')
    if path.is_dir():
        path /= Path(key).name

    total = 0
    start = time()

    def callback(nbytes):
        # Invoked by boto3 with the number of bytes transferred since the last call
        nonlocal total
        total += nbytes
        # A coarse clock may report a zero interval for the first chunks;
        # clamp it to avoid ZeroDivisionError when computing the rate
        elapse = max(time() - start, 1e-9)
        if total < 1024:
            print(
                f"\rDownloaded {total:4d} bytes at {total / elapse:.2f} bytes/s",
                end='')
        elif total < 1024**2:
            KB = total / 1024
            print(f"\rDownloaded {KB:4.2f}KB at {KB/elapse:4.2f} KB/s",
                  end='')
        else:
            MB = total / 1024**2
            print(f"\rDownloaded {MB:8.2f}MB at {MB/elapse:6.2f} MB/s",
                  end='')
        sys.stdout.flush()

    # Reserve a unique name in the destination directory and close the handle
    # right away: only the name is needed, and delete=False keeps the file
    # around for boto3 to write to.
    with tempfile.NamedTemporaryFile(delete=False, dir=path.parent) as tmp:
        tmp_name = tmp.name
    try:
        s3.download_file(bucket,
                         key,
                         tmp_name,
                         Callback=callback if progress else None)
    except ClientError as e:
        print()
        logging.error(
            f"Failed to download s3://{bucket}/{key} to {path}: {e}")
        return False
    else:
        from ml import shutil
        shutil.move(tmp_name, path)
        print()
        logging.info(
            f"Succeeded to download s3://{bucket}/{key} to {path}")
        return True
    finally:
        # The temporary file is still present only if the download or the
        # move failed; do not leave it behind in the destination directory
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
예제 #23
0
파일: __init__.py 프로젝트: necla-ml/ML
# Re-export everything the sibling `backend` module exposes.
from .backend import *

# Re-export torch.distributed as well when it can be imported; otherwise only
# log a warning so that importing this package still succeeds.
try:
    from torch.distributed import *
except ImportError as e:
    from ml import logging
    logging.warning(f"No pytorch installation for distributed execution")
예제 #24
0
def make_dataset_yolo5():
    '''Compose a dataset in YOLO5/COCO format out of one or more supported datasets.

    Arguments are parsed from the command line. For every requested split, the
    image paths listed in the split file of each source dataset are symlinked
    into ``<output>/images``, the label files derived from them are symlinked
    into ``<output>/labels``, and the resolved image paths are written to
    ``<output>/<split>.txt``.

    Usage:
        make_dataset_yolo5 coco/ SKU110K/ -o Retail81 --splits val train

    Raises:
        ValueError: if the name of a source directory does not match an
            attribute of ``ml.vision.datasets``
    '''
    split_names = ('train', 'val', 'test')
    # The text must go to `description`: the first positional parameter of
    # ArgumentParser is `prog`, which would replace the program name in usage
    parser = argparse.ArgumentParser(
        description=
        "Compose a dataset in YOLO5 format from one or more supported datasets"
    )
    parser.add_argument('sources',
                        nargs='+',
                        help='One or more paths to supported source datasets')
    parser.add_argument('-o',
                        '--output',
                        required=True,
                        help='Composed output dataset path')
    # Reject unknown splits up front rather than with a KeyError after the
    # output directories have already been created
    parser.add_argument('--splits',
                        nargs='+',
                        default=['val'],
                        choices=split_names,
                        help='Dataset split(s) to convert')
    cfg = parser.parse_args()

    dataset = Path(cfg.output)
    dataset.mkdir(parents=True, exist_ok=True)
    images = dataset / 'images'
    labels = dataset / 'labels'
    images.mkdir(exist_ok=True)
    labels.mkdir(exist_ok=True)
    logging.info(
        f"Composing dataset={dataset.name}, splits={cfg.splits}, path={dataset}"
    )

    splits = {name: [] for name in split_names}
    logging.info(f"Collecting source dataset splits")
    from ml.vision import datasets
    for src in cfg.sources:
        src = Path(src)
        # The source directory name selects the dataset module
        name = src.name.lower()
        if not hasattr(datasets, name):
            raise ValueError(f"Unsupported source dataset '{src}'")
        ds = getattr(datasets, name)
        for split in cfg.splits:
            splits[split].append(str(src / ds.SPLITS[split]))

    for split, split_files in splits.items():
        if not split_files:
            continue
        t = time.time()
        files = '\n'.join(split_files)
        logging.info(f"Working on {split} split from \n{files}")
        paths = []
        for file in split_files:
            with open(file) as f:
                entries = f.read().splitlines()
            paths.extend(entries)
            logging.info(f"Included {len(entries)} entries from {file}")
        with open(dataset / f"{split}.txt", 'w') as sf:
            for path in paths:
                path = Path(path)
                img_path = images / path.name
                label_path = labels / f"{path.stem}.txt"
                # Replace links left over from a previous run
                if img_path.is_symlink():
                    img_path.unlink()
                    logging.warning(f"Removed existing {img_path}")
                if label_path.is_symlink():
                    label_path.unlink()
                    logging.warning(f"Removed existing {label_path}")
                img_path.symlink_to(path)
                print(img_path, '->', f"{path}")
                # The label path mirrors the image path with 'images'
                # replaced by 'labels' in the parent directory
                parent = str(path.parent).replace('images', 'labels')
                label_path.symlink_to(f"{parent}/{path.stem}.txt")
                print(label_path, '->', f"{parent}/{path.stem}.txt")
                print(img_path.resolve(), file=sf)
        t = time.time() - t
        logging.info(
            f"Processed and saved {len(paths)} entries to {sf.name} in {t:.3}s"
        )
예제 #25
0
파일: h264.py 프로젝트: necla-ml/ML
def _trim_trailing_zeros(bitstream, floor, end):
    '''Return `end` moved left past zero bytes without reaching index `floor`.'''
    while end - 1 > floor and bitstream[end - 1] == 0x00:
        end -= 1
    return end


def NALUParser(bitstream, workaround=False):
    '''Split a bitstream into NALUs at 3-byte/4-byte start codes.

    Args:
        bitstream(bytes-like): a writable bytes-like object that incurs zero copy for slicing
        workaround(bool): removing trailing zero bytes in non-VCL NALUs

    Yields:
        ((start, forbidden, ref_idc, type), nalu): offset of the NALU in the
        bitstream, the fields returned by parseNALUHeader() for its header
        byte, and the slice of the bitstream from the start code up to the
        next start code or the end.
    '''

    # FIXME Assume at most three NALUs for fast parsing
    size = len(bitstream)
    pos = 0
    start = 0
    # Start code length in bytes of the NALU beginning at `start`;
    # stays -1 until the first start code has been seen.
    prefix = -1
    while pos < size:
        if bitstream[pos:pos + 3] == START_CODE24:
            found = 24 // 8
        elif bitstream[pos:pos + 4] == START_CODE32:
            found = 32 // 8
        else:
            pos += 1
            continue

        if start < pos:
            # NOTE(review): bytes preceding the first start code are yielded
            # too, with prefix == -1, i.e. the header is read from the last
            # byte of the bitstream as before. Confirm whether callers rely
            # on this.
            header = bitstream[start + prefix]
            end = pos
            if workaround:
                # FIXME Three consecutive zero bytes may be rejected by e.g. KVS
                # - encoder simply inserts an additional zero byte
                # - encoder forgot to set the stop bit to 1
                # - encoder forgot to insert an additional byte with MSB set to 1
                # The scan stops before the header byte so it can neither eat
                # the header nor wrap around to negative indices.
                end = _trim_trailing_zeros(bitstream, start + prefix, pos)
                for skipped in range(pos - 1, end - 1, -1):
                    logging.warning(
                        f"Skip NALU trailing zero byte at pos {skipped}")
            yield (start, *parseNALUHeader(header)), bitstream[start:end]

        prefix = found
        start = pos
        pos += found
        if pos >= size:
            # Start code at the very end: there is no header byte to read
            break
        _, _, nalu_type = parseNALUHeader(bitstream[pos])
        if nalu_type in (NALU_t.NIDR, NALU_t.IDR):
            # XXX Skip to the end for speedup
            pos = size

    # Last NALU to the end, unless it is a start code without a header byte
    if start < pos and start + prefix < size:
        header = bitstream[start + prefix]
        forbidden, ref_idc, nalu_type = parseNALUHeader(header)
        end = pos
        if workaround and nalu_type not in (NALU_t.IDR, NALU_t.NIDR):
            # FIXME Three consecutive zero bytes may be rejected by e.g. KVS
            # - encoder simply inserts an additional zero byte
            # - encoder forgot to set the stop bit to 1
            # - encoder forgot to insert an additional byte with MSB set to 1
            end = _trim_trailing_zeros(bitstream, start + prefix, pos)
            for skipped in range(pos - 1, end - 1, -1):
                logging.warning(
                    f"Skip NALU(type={nalu_type}) with trailing zero byte at pos {skipped}"
                )
        yield (start, forbidden, ref_idc, nalu_type), bitstream[start:end]
예제 #26
0
def main_worker(gpu, ngpus_per_node, args):
    """Build the model and data loaders, then train or evaluate in this process.

    Args:
        gpu: GPU index to use in this process, or None; stored in ``args.gpu``
        ngpus_per_node: number of GPUs per node, used to derive the global rank
            and to split the batch size and workers in distributed mode
        args: parsed options; ``gpu``, ``rank``, ``batch_size``, ``workers``
            and ``start_epoch`` may be updated in place

    The module-level ``best_acc1`` is updated while training and when resuming
    from a checkpoint.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch in X101_WSL:
            # model = torch.hub.load('facebookresearch/WSL-Images', args.arch)
            from ..models.backbone import resnext101
            mult = 2 * (X101_WSL.index(args.arch))  # 0, 2, 4, 6
            mult = 1 if mult == 0 else mult         # 1, 2, 4, 6
            model = resnext101(pretrained=True, classifier=True, groups=32, width_per_group=8 * mult)
        else:
            model = models.__dict__[args.arch](pretrained=True)
        if args.deploy:
            if args.gpu is None:
                args.gpu = 0
                logging.warning("TensorRT deployment forces using only one GPU")
            model = deploy(model, args.arch,
                           data=f"{args.data}/val",
                           batch_size=args.batch_size,
                           resize=256,
                           size=tuple(map(int, args.deploy_size)),
                           fp16=args.deploy_fp16,
                           int8=args.deploy_int8,
                           strict_type_constraints=args.deploy_strict,
                           reload=args.deploy_reload)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        # Without CUDA the model stays on the CPU (see above), so must the loss
        criterion = criterion.cuda(args.gpu)
    # No optimizer is needed for evaluation only
    optimizer = None
    if not args.evaluate:
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and torch.is_tensor(best_acc1):
                # best_acc1 may be from a checkpoint from a different GPU;
                # a plain number has no device and needs no transfer
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            if optimizer is not None:
                # Evaluating a resumed checkpoint has no optimizer to restore
                optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    # Shuffling is delegated to the sampler when one is in use
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        # With multiprocessing, only the first process on each node saves
        if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
예제 #27
0
파일: driver.py 프로젝트: necla-ml/ML-WS
    def handler(self, source, headers, message):
        '''Handle a message received on one of the broker subscriptions.

        Job messages start, reload or stop the streamer, update the selector
        of the job subscription accordingly and report the outcome to the
        monitor. Messages from any other subscription are only logged.

        Args:
            source: id of the subscription the message arrived on
            headers(dict): message headers; 'correlation-id' and 'type' are read
            message(str): JSON encoded message body; 'stream_id' is read
        '''
        if source != SUBSCRIPTION_ID.JOB:
            # SUBSCRIPTION_ID.ADMIN
            logging.warning('Admin Listener not implemented yet')
            return

        message = json.loads(message)
        correlation_id = headers.get('correlation-id')
        msg_type = headers.get('type')
        stream_id = message.get('stream_id')

        msg = {
            'stream_id': stream_id,
            'timestamp': time.time()
        }

        def subscribe(selector):
            # (Re)subscribe to the job queue filtered by the given selector
            self.broker.subscribe(
                destination=self.job_destination,
                subscription_id=SUBSCRIPTION_ID.JOB,
                headers={
                    'selector': selector
                }
            )

        try:
            for_this_stream = int(correlation_id) == stream_id
        except (TypeError, ValueError):
            # A missing or non-numeric correlation id cannot address a stream;
            # treat it as unmatched instead of raising out of the handler
            logging.warning(f'[{self.name}] Skipping: Invalid correlation id {correlation_id!r}')
            for_this_stream = False

        if for_this_stream:
            # Request to turn off streaming received
            if msg_type == MSG_TYPE.OFF:
                logging.warning(f'[{self.name}] Stopping streaming job for stream_id: {stream_id}')
                self.streamer.stop(timeout=STREAMER_JOIN_TIMEOUT)
                subscribe('JMSCorrelationID=1')
                msg['msg_type'] = MSG_TYPE.OFF
                logging.info(f'[{self.name}] Listening to streaming queue with selector: {self.broker.headers}')
            elif msg_type == MSG_TYPE.RELOAD:
                # stream attribute changed, update args and restart streaming
                logging.info(f'[{self.name}] Restarting streaming on changes in attributes')
                self.streamer.args = message
                # stop streaming thread
                self.streamer.stop(timeout=STREAMER_JOIN_TIMEOUT)
                # start streaming thread with latest args
                self.streamer.start()
                subscribe(f'JMSCorrelationID={stream_id}')
                msg['payload'] = 'Streaming restarted on changes in attributes'
                msg['msg_type'] = MSG_TYPE.RELOAD
            else:
                logging.warning(f'[{self.name}] Skipping: Invalid message type')
        elif correlation_id == '1':
            # new streaming job request
            # subscribe to individual stream attribute changes e.g fps, profile, etc
            # ==> same queue but filter based on correlation id value using selector
            self.streamer.args = message
            self.streamer.start()
            msg['msg_type'] = MSG_TYPE.STREAMING
            subscribe(f'JMSCorrelationID={stream_id}')
            logging.info(f'[{self.name}] Listening to streaming queue with selector: {self.broker.headers}')

        # put msg to monitor queue ==> event_queue
        self.monitor.put_msg(msg)
예제 #28
0
def load_state_dict_from_url(url,
                             model_dir=None,
                             map_location=None,
                             force_reload=False,
                             progress=True,
                             check_hash=False,
                             file_name=None):
    r"""Loads the Torch serialized object at the given URL.

    If the object is already present in `model_dir`, it's deserialized and returned
    unless `force_reload` is set.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        url (string): URL of the object to download; the ``s3`` scheme is
            downloaded through ``download_s3``
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        force_reload (bool, optional): remove an existing download and fetch it again.
            Default: False
        progress (bool, optional): whether or not to display a progress bar to stderr.
            Default: True
        check_hash(bool, optional): If True, the filename part of the URL should follow the naming convention
            ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
            digits of the SHA256 hash of the contents of the file. The hash is used to
            ensure unique names and to verify the contents of the file.
            Default: False
        file_name (string, optional): name for the downloaded file. Filename from `url` will be used if not set.

    Raises:
        IOError: if the download fails

    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    # FIXME Temporary workaround for pytorch-1.6.0 introducing new checkpoint save in zip format
    # Added argument: force_reload
    # Alternative url scheme: s3

    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn(
            'TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    if model_dir is None:
        hub_dir = get_dir()
        model_dir = os.path.join(hub_dir, 'checkpoints')

    # An already existing directory is fine; any other OSError propagates
    os.makedirs(model_dir, exist_ok=True)

    spec = urlparse(url)
    filename = os.path.basename(spec.path)
    if file_name is not None:
        filename = file_name
    cached_file = os.path.join(model_dir, filename)

    download = True
    if os.path.exists(cached_file):
        if force_reload:
            os.unlink(cached_file)
            logging.warning(
                f"Forced removing existing download: {cached_file}")
        else:
            download = False
            logging.warning(
                f"Download exists: {cached_file}, specify force_reload=True to remove if necessary"
            )

    if download:
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        try:
            if spec.scheme == 's3':
                bucket = spec.netloc
                key = spec.path[1:]
                # download_s3 reports failure by returning False rather than
                # raising; surface it here instead of failing later in
                # torch.load on a missing file
                if download_s3(bucket, key, cached_file,
                               progress=progress) is False:
                    raise IOError(f"S3 download of {url} failed")
            else:
                hash_prefix = None
                if check_hash:
                    r = HASH_REGEX.search(
                        filename)  # r is Optional[Match[str]]
                    hash_prefix = r.group(1) if r else None
                download_url_to_file(url,
                                     cached_file,
                                     hash_prefix,
                                     progress=progress)
        except Exception as e:
            raise IOError(f"Failed to download to {cached_file}: {e}") from e

    # Legacy zip checkpoints are intentionally not special-cased anymore:
    #   if _is_legacy_zip_format(cached_file):
    #       return _legacy_zip_load(cached_file, model_dir, map_location)
    logging.info(f'Loading checkpoint from {cached_file}')
    return torch.load(cached_file, map_location=map_location)
예제 #29
0
파일: api.py 프로젝트: necla-ml/ML
from ml import logging

# Pull the public names of ml.vision.transforms, plus its `functional` member,
# into this module's namespace. Any failure while importing is downgraded to a
# warning carrying an install hint, so importing this module itself does not fail.
try:
    from ml.vision.transforms import *
    from ml.vision.transforms import functional
except Exception as e:
    # Broad on purpose: the error text is forwarded verbatim in the warning below.
    logging.warning(
        f"{e}, ml.vision.transforms unavailable, run `mamba install ml-vision -c NECLA-ML` to install"
    )
예제 #30
0
파일: avsource.py 프로젝트: necla-ml/ML-WS
    def read_video(self, session, format='BGR'):
        '''Generate ``(meta, frame)`` pairs from the video stream of a session.

        Packets are pulled from ``session['video']['stream']`` with a one-packet
        lookahead: every packet is parked in ``meta['prev']`` and only decoded and
        yielded once its successor has been read, because the successor's pts is
        what the frame duration is derived from.

        For session formats containing 'hls', 'rtsp' or '264' each packet is rebuilt
        from its NALUs (filtered by NALU type when ``meta['workaround']`` is set);
        the very first packet additionally gets the NALUs found in the codec
        extradata (CPD) prepended.

        Args:
            session(dict): session state; the keys read here are 'video', 'start',
                'format', 'decoding' and 'rt'. ``session['start']`` is rewritten
                while leading non-key frames are skipped.
            format(str): 'BGR' or 'RGB' to convert a decoded frame into an ndarray;
                any other value yields the decoded frame object unchanged. Only
                relevant when ``session['decoding']`` is truthy, otherwise the
                pending packet itself is yielded.

        Yields:
            (meta, frame): ``meta`` is ``session['video']``, updated in place
            ('prev', 'count', 'keyframe', 'time', 'duration', 'drifting', and
            'width'/'height' when decoding).

        Raises:
            Exception: whatever ``codec.decode()`` raises, re-raised after logging.
        '''
        meta = session['video']
        stream = meta['stream']
        codec = meta['codec']
        workaround = meta['workaround']
        while True:
            try:
                pkt = next(stream)
            except StopIteration:
                # The iterator ran dry without delivering a trailing empty packet.
                # There is no packet to inspect, so finish the generator here rather
                # than dereferencing None further down.
                # NOTE(review): a packet still pending in meta['prev'] is not flushed
                # on this path - confirm whether the stream can end this way.
                logging.warning("EOF/EOS on exhausted stream")
                return
            now = time.time()
            prev = meta.get('prev', None)
            if prev is None:
                if not pkt.is_keyframe:
                    # Some RTSP source may not send key frame to begin with e.g. wisecam
                    logging.warning(
                        f"No key frame to begin with, skip through")
                    session['start'] = now
                    continue
                meta['keyframe'] = pkt.is_keyframe
                meta['time'] = session['start']
                sformat = session['format']

                # XXX Stream container package format determines H.264 NALUs in AVCC or Annex B.
                # TODO Streaming NALUs in AVCC
                if 'hls' in sformat or 'rtsp' in sformat or '264' in sformat:
                    # XXX In case of out of band CPD: SPS/PPS in AnnexB.
                    CPD = []
                    if codec.extradata is not None:
                        for (pos, _, _,
                             nalu_type), nalu in NALUParser(codec.extradata,
                                                            workaround=workaround):
                            if hasStartCode(nalu):
                                CPD.append(nalu)
                                logging.info(
                                    f"CPD {NALU_t(nalu_type).name} at {pos}: {nalu[:8]} ending with {nalu[-1:]}"
                                )
                            else:
                                logging.warning(
                                    f"Invalid CPD NALU({nalu_type}) at {pos}: {nalu[:8]} ending with {nalu[-1:]}"
                                )
                                if not CPD:
                                    # Skip all
                                    break
                    NALUs = []
                    if workaround:
                        # FIXME workaround before KVS MKVGenerator deals with NALUs ending with a zero byte
                        #   https://github.com/awslabs/amazon-kinesis-video-streams-producer-sdk-cpp/issues/491
                        for (pos, _, _,
                             nalu_type), nalu in NALUParser(memoryview(pkt),
                                                            workaround=workaround):
                            assert hasStartCode(
                                nalu
                            ), f"frame[{meta['count']+1}] NALU(type={nalu_type}) at {pos} without START CODE: {nalu[:8].tobytes()}"
                            if nalu_type in (NALU_t.SPS, NALU_t.PPS):
                                if CPD:
                                    # NOTE: some streams could have multiple UNSPECIFIED(0) NALUs within a single packet with SPS/PPS
                                    #assert len(CPD) == 2, f"len(CPD) == {len(CPD)}, not 2 for SPS/PPS"
                                    # Index into CPD by offset from SPS.
                                    # NOTE(review): presumes CPD holds SPS then PPS - confirm.
                                    ordinal = nalu_type - NALU_t.SPS
                                    if nalu == CPD[ordinal]:
                                        logging.info(
                                            f"frame[{meta['count']+1}] same {NALU_t(nalu_type).name}({nalu[:8].tobytes()}) at {pos} as in CPD({CPD[ordinal][:8]})"
                                        )
                                    else:
                                        # FIXME may expect the CPD to be inserted in the beginning?
                                        logging.warning(
                                            f"frame[{meta['count']+1}] inconsistent {NALU_t(nalu_type).name}({nalu[:8].tobytes()}) at {pos} with CPD({CPD[ordinal][:8]})"
                                        )
                                        print(f"CPD {NALU_t(nalu_type).name}:",
                                              CPD[ordinal])
                                        print(f"NALU {NALU_t(nalu_type).name}:",
                                              nalu.tobytes())
                                        # XXX bitstream may present invalid CPD => replacement with bitstream SPS/PPS
                                        CPD[ordinal] = nalu
                                else:
                                    NALUs.append(nalu)
                                    logging.info(
                                        f"frame[{meta['count']+1}] {NALU_t(nalu_type).name} at {pos}: {nalu[:8].tobytes()} ending with {nalu[-1:].tobytes()}"
                                    )
                            # XXX KVS master is ready to filter out non-VCL NALUs as part of the CPD
                            # elif nalu_type in (NALU_t.IDR, NALU_t.NIDR):
                            elif nalu_type in (NALU_t.AUD, NALU_t.SEI, NALU_t.IDR,
                                               NALU_t.NIDR):
                                NALUs.append(nalu)
                                logging.info(
                                    f"frame[{meta['count']+1}] {NALU_t(nalu_type).name} at {pos}: {nalu[:8].tobytes()}"
                                )
                            else:
                                # FIXME may expect CPD to be inserted in the beginning?
                                logging.warning(
                                    f"frame[{meta['count']+1}] skipped unexpected NALU(type={nalu_type}) at {pos}: {nalu[:8].tobytes()}"
                                )
                        logging.info(
                            f"{'key ' if pkt.is_keyframe else ''}frame[{meta['count']}] combining CPD({len(CPD)}) and NALUs({len(NALUs)})"
                        )
                    else:
                        NALUs.append(memoryview(pkt))
                        logging.info(
                            f"{'key ' if pkt.is_keyframe else ''}frame[{meta['count']}] prepending CPD({len(CPD)})"
                        )
                    # Rebuild the first packet as CPD followed by the kept NALUs,
                    # carrying over the timing fields of the packet it replaces.
                    packet = av.Packet(bytearray(b''.join(CPD + NALUs)))
                    packet.dts = pkt.dts
                    packet.pts = pkt.pts
                    packet.time_base = pkt.time_base
                    pkt = packet
                    if pkt.pts is None:
                        logging.warning(
                            f"Initial packet dts/pts={pkt.dts}/{pkt.pts}, time_base={pkt.time_base}"
                        )
                    elif pkt.pts > 0:
                        logging.warning(
                            f"Reset dts/pts of 1st frame from {pkt.pts} to 0")
                        pkt.pts = pkt.dts = 0
                elif 'dash' in sformat:
                    # TODO In case of out of band CPD: SPS/PPS in AVCC.
                    logging.info(f"DASH AVCC extradata: {codec.extradata}")
                    logging.info(
                        f"pkt[:16]({pkt.is_keyframe}) {memoryview(pkt)[:16].tobytes()}"
                    )
            else:
                # Looked up here as well so this branch does not depend on the
                # first-packet branch having run earlier in this same generator.
                sformat = session['format']
                keyframe = pkt.is_keyframe
                logging.debug(
                    f"packet[{meta['count']}] {'key ' if keyframe else ''}dts/pts={pkt.dts}/{pkt.pts}, time_base={pkt.time_base}, duration={pkt.duration}"
                )
                if 'hls' in sformat or 'rtsp' in sformat or '264' in sformat:
                    NALUs = []
                    if workaround:
                        for (pos, _, _,
                             nalu_type), nalu in NALUParser(memoryview(pkt),
                                                            workaround=workaround):
                            # assert hasStartCode(nalu), f"frame[{meta['count']+1}] NALU(type={nalu_type}) at {pos} without START CODE: {nalu[:8].tobytes()}"
                            # FIXME KVS master is not ready to take AUD/SEI as part of the CPD
                            # if nalu_type in (NALU_t.SPS, NALU_t.PPS, NALU_t.IDR, NALU_t.NIDR):
                            if nalu_type in (NALU_t.AUD, NALU_t.SEI, NALU_t.SPS,
                                             NALU_t.PPS, NALU_t.IDR, NALU_t.NIDR):
                                NALUs.append(nalu)
                                logging.debug(
                                    f"frame[{meta['count']+1}] {NALU_t(nalu_type).name} at {pos}: {nalu[:8].tobytes()}"
                                )
                            else:
                                # FIXME may expect CPD to be inserted?
                                logging.debug(
                                    f"frame[{meta['count']+1}] skipped NALU(type={nalu_type}) at {pos}: {nalu[:8].tobytes()} ending with {nalu[-1:].tobytes()}"
                                )
                    else:
                        NALUs.append(memoryview(pkt))
                    # XXX Assume no SPS/PPS change
                    packet = av.Packet(bytearray(b''.join(NALUs)))
                    packet.dts = pkt.dts
                    packet.pts = pkt.pts
                    packet.time_base = pkt.time_base
                    pkt = packet
                # Without decoding, the pending packet itself is what gets yielded.
                frame = prev
                if session['decoding']:
                    try:
                        frames = codec.decode(prev)
                        if not frames:
                            logging.warning(
                                f"Decoded nothing, continue to read...")
                            meta['prev'] = pkt
                            meta['count'] += 1
                            continue
                    except Exception as e:
                        logging.error(
                            f"Failed to decode video packet of size {prev.size}: {e}"
                        )
                        raise
                    else:
                        # print(prev, frames)
                        frame = frames[0]
                        meta['width'] = frame.width
                        meta['height'] = frame.height
                        if format == 'BGR':
                            # Reverse the channel axis of the RGB array to get BGR.
                            frame = frame.to_rgb().to_ndarray()[:, :, ::-1]
                        elif format == 'RGB':
                            frame = frame.to_rgb().to_ndarray()
                if session['rt']:
                    # Live source from network or local camera encoder.
                    # Bitstream contains no pts but frame duration.
                    # Adaptive frame duration on drift from wall clock:
                    #   - Faster for long frame buffering
                    #   - Fall behind for being slower than claimed FPS: resync as now
                    if pkt.pts is not None and not meta['drifting']:
                        # Check if drifting
                        if prev.pts is None:
                            prev.dts = prev.pts = 0
                            logging.warning(
                                "Reset previous packet dts/pts from None to 0")
                        duration = float((pkt.pts - prev.pts) * pkt.time_base)
                        # assert duration > 0, f"pkt.pts={pkt.pts}, prev.pts={prev.pts}, pkt.time_base={pkt.time_base}, pkt.duration={pkt.duration}, prev.duration={prev.duration}, duration={duration}"
                        if duration <= 0:
                            # FIXME RTSP from Dahua/QB and WiseNet/Ernie
                            # Substitute a pts half of the last frame duration past prev.
                            pts = prev.pts + (meta['duration'] /
                                              pkt.time_base) / 2
                            duration = float((pts - prev.pts) * pkt.time_base)
                            logging.warning(
                                f"Non-increasing pts: pkt.pts={pkt.pts}, prev.pts={prev.pts} => pts={pts}, duration={duration}"
                            )
                            pkt.pts = pts

                        timestamp = meta['time'] + duration
                        if meta['adaptive']:
                            # adaptive frame duration only if not KVS
                            diff = abs(timestamp - now)
                            threshold = meta['thresholds']['drifting']
                            if diff > threshold:
                                meta['drifting'] = True
                                logging.warning(
                                    f"Drifting video timestamps: abs({timestamp:.3f} - {now:.3f}) = {diff:.3f} > {threshold}s"
                                )
                    if pkt.pts is None or meta['drifting']:
                        # Real-time against wall clock, clamped to 0.5x..1.5x of
                        # the nominal frame period 1/fps.
                        duration = now - meta['time']
                        duration = min(1.5 / meta['fps'], duration)
                        duration = max(0.5 / meta['fps'], duration)
                        meta['duration'] = duration
                        yield meta, frame
                        meta['time'] += duration
                    else:
                        meta['duration'] = duration
                        yield meta, frame
                        meta['time'] = timestamp
                else:
                    # TODO: no sleep for being handled by renderer playback
                    # Simulating RT
                    meta['duration'] = 1.0 / meta['fps']
                    slack = (meta['time'] + meta['duration']) - now
                    if slack > 0:
                        logging.debug(
                            f"Sleeping for {slack:.3f}s to simulate RT source")
                        time.sleep(slack)
                    yield meta, frame
                    meta['time'] += meta['duration']
                meta['keyframe'] = keyframe
            if pkt.size == 0:
                logging.warning(f"EOF/EOS on empty packet")
                return
            else:
                # Park the packet; it is decoded and yielded on the next iteration.
                meta['prev'] = pkt
                meta['count'] += 1