예제 #1
0
    def test_get_episode(self):
        '''Check that a video attached to an episode resolves back to it.

        Builds a show with two seasons and three episodes, attaches one
        video to the first episode of season 1, and asserts that the
        video's ``episode`` relationship returns episode number 1.
        '''
        configure_for_unittest()
        from model import Video, Show, Season, Episode

        # Nodes.
        show = Show(title='House of Cards').save()
        first_season = Season(show=show, number=1).save()
        second_season = Season(show=show, number=2).save()
        pilot = Episode(season=first_season,
                        number=1,
                        release_date=dt(2010, 1, 1)).save()
        s1_e2 = Episode(season=first_season, number=2).save()
        s2_e1 = Episode(season=second_season, number=1).save()
        video = Video(link='vk.com').save()

        # Relationships.
        for season in (first_season, second_season):
            show.seasons.connect(season)
        for episode in (pilot, s1_e2):
            first_season.episodes.connect(episode)
        second_season.episodes.connect(s2_e1)
        pilot.videos.connect(video)

        video.refresh()

        self.assertEqual(video.episode.get().number, 1)
예제 #2
0
def main():
    """Create the User, Video and Config tables inside one transaction.

    Prints a confirmation and returns 0 (presumably used as the process
    exit status — confirm with the caller).
    """
    with TransactionManager.transaction():
        for model_cls in (User, Video, Config):
            model_cls.init_table()

    print("Succeeded to init db.")
    return 0
예제 #3
0
def _tweet_author_to_user(tweet):
    """Build a ``User`` from the ``user`` object embedded in a tweet dict."""
    user = tweet["user"]
    return User(
        user_id=user["id"],
        name=user["name"],
        screen_name=user["screen_name"],
        thumbnail_url=user["profile_image_url_https"])


def extract_video_tweets(twapi_search_response: List[dict]) -> List[Video]:
    """Convert tweets that carry a video into ``Video`` objects.

    :param twapi_search_response: iterable of tweet dicts from the Twitter
        search API. It is iterated tweet by tweet, not used as a mapping.
    :return: one ``Video`` per tweet for which ``select_video_url_from_tweet``
        returned a result; tweets without a video are skipped.

    For retweets and quote tweets, the author and body come from the embedded
    source tweet, and the outer tweet's author is stored as
    ``retweeted_user``. ``tweet_id`` and ``created_at`` always come from the
    outer tweet.
    """
    video_list = []

    for tweet in twapi_search_response:
        selected_video = select_video_url_from_tweet(tweet)
        if selected_video is None:
            continue
        video_url, video_type = selected_video[0], selected_video[1]

        # A retweet takes precedence over a quote when both are present.
        if "retweeted_status" in tweet:
            source_tweet = tweet["retweeted_status"]
        elif "quoted_status" in tweet:
            source_tweet = tweet["quoted_status"]
        else:
            source_tweet = None

        if source_tweet is not None:
            author = _tweet_author_to_user(source_tweet)
            retweeter = _tweet_author_to_user(tweet)
            body = source_tweet["full_text"]
        else:
            author = _tweet_author_to_user(tweet)
            retweeter = None
            body = tweet["full_text"]

        video_list.append(Video(
            tweet_id=tweet["id"],
            author=author,
            retweeted_user=retweeter,
            body=body,
            created_at=twitter_date_to_datetime(tweet["created_at"]),
            video_url=video_url,
            video_type=video_type))

    return video_list
예제 #4
0
    def test_create_duplicates(self):
        '''Saving a second Video whose link collides with an existing one must fail.

        The two links differ only by the ``https://`` scheme, so the expected
        ``UniqueProperty`` relies on ``Video`` normalising links before the
        uniqueness check (NOTE(review): presumably done in ``model`` — confirm).
        '''
        configure_for_unittest()
        from model import Video
        from neomodel import UniqueProperty

        Video(link='https://vk.com/blablabla').save()
        duplicate = Video(link='vk.com/blablabla')

        with self.assertRaises(UniqueProperty):
            duplicate.save()
예제 #5
0
    def test_create_duplicates(self):
        '''Expect UniqueProperty when a duplicate Video link is saved.

        The first video is stored with an ``https://`` prefix and the second
        without it; the test expects both to count as the same link
        (NOTE(review): link normalisation is assumed to live in ``model.Video``).
        '''
        configure_for_unittest()
        from model import Video
        from neomodel import UniqueProperty

        original_link = 'https://vk.com/blablabla'
        colliding_link = 'vk.com/blablabla'

        Video(link=original_link).save()
        second = Video(link=colliding_link)

        with self.assertRaises(UniqueProperty):
            second.save()
예제 #6
0
def stabilization(optical_flow_method, debug: bool = False, **kwargs):
    """
    Perform video stabilization using the given optical flow method.

    On every 4th frame, corners are detected in the previous frame and the
    flow from the previous to the current frame is estimated. The mean of the
    non-zero flow vectors is then subtracted from an accumulated (x, y)
    offset. The current frame is translated by that offset, and both the
    original and the stabilized frame are written as JPEGs under
    ``../video/block/``.

    Idea: test some metric using a known logo. Using ORB matching we could detect if it moves.

    :param optical_flow_method: callable ``(prev_frame, frame, points)``
        returning a flow field indexed as ``flow[:, :, 0]`` / ``flow[:, :, 1]``
        (assumed shape (H, W, 2) — confirm against the implementations)
    :param debug: whether to show debug plots
    :param kwargs: accepted but currently unused
    """
    video = Video('../datasets/stabilization/piano')
    feature_params = dict(maxCorners=500,
                          qualityLevel=0.3,
                          minDistance=7,
                          blockSize=7)
    previous_frame = None
    accum_flow = np.zeros(2)
    count = 0
    for i, frame in tqdm(enumerate(video.get_frames()),
                         total=len(video),
                         file=sys.stdout):
        if previous_frame is not None and i % 4 == 0:
            rows, cols, _ = frame.shape
            p0 = cv2.goodFeaturesToTrack(
                cv2.cvtColor(previous_frame, cv2.COLOR_BGR2GRAY),
                mask=None,
                **feature_params)
            flow = optical_flow_method(previous_frame, frame, p0)
            if debug:
                show_optical_flow_arrows(previous_frame, flow)

            # Boolean indexing yields an (N, 2) array of non-zero flow
            # vectors. Reduce over axis 0 only, to keep separate x and y
            # means; axis=(0, 1) would collapse both into a single scalar.
            moving = np.logical_or(flow[:, :, 0] != 0, flow[:, :, 1] != 0)
            m = np.mean(flow[moving], axis=0) if moving.any() else None
            # Skip frames with no usable motion estimate. Otherwise a NaN
            # mean would be added and would corrupt every later offset.
            if m is not None and not np.isnan(m).any():
                accum_flow -= m
            transform = np.float32([[1, 0, accum_flow[0]],
                                    [0, 1, accum_flow[1]]])
            frame2 = cv2.warpAffine(frame, transform, (cols, rows))

            if debug:
                plt.figure()
                plt.imshow(cv2.cvtColor(frame2, cv2.COLOR_BGR2RGB))
                plt.axis('off')
                plt.show()
            # Output file names (including the 'Origianl' spelling) are kept
            # unchanged so previously written frames keep the same names.
            cv2.imwrite("../video/block/OrigianlFrame%04d.jpg" % count,
                        frame)  # save frame as JPEG file
            cv2.imwrite("../video/block/StabilizedFrame%04d.jpg" % count,
                        frame2)  # save frame as JPEG file

            count += 1
        previous_frame = frame
예제 #7
0
def upload():
    """Store an uploaded video under a random name and register it.

    Expects a multipart form with a ``video`` file field. The file is saved
    to ``app.config['UPLOAD_DIR']`` as ``<random string><original extension>``
    and a ``Video`` row is created for it.

    :return: JSON of the created video, or a ``json_error`` response when the
        ``video`` field is missing.
    """
    if 'video' not in request.files:
        return json_error("Video file not found")

    video_file = request.files['video']
    filename = secure_filename(video_file.filename)
    # splitext keeps the leading dot and yields '' when there is no
    # extension. split('.')[-1] would return the whole name instead,
    # e.g. 'clip' would become '<random>.clip'.
    ext = os.path.splitext(filename)[1]
    new_filename = generate_random_string() + ext
    video_file.save(os.path.join(app.config['UPLOAD_DIR'], new_filename))

    video = Video(filename=new_filename)
    video.save()

    return jsonify(video.to_dict())
예제 #8
0
def gaussian_model(video: Video,
                   frame_start: int,
                   background_mean: np.ndarray,
                   background_std: np.ndarray,
                   alpha: float = 2.5,
                   pixel_value: PixelValue = PixelValue.GRAY,
                   total_frames: int = None,
                   disable_tqdm=False) -> Iterator[tuple]:
    """Segment the foreground with a fixed (non-adaptive) per-pixel Gaussian model.

    Each frame from ``frame_start`` on is reduced to one channel: the
    grayscale mean divided by 255, or the OpenCV hue divided by 180. A pixel
    is marked as foreground when
    ``|value - background_mean| >= alpha * (background_std + 5/255)``.

    :param video: frame source, read via ``video.get_frames(frame_start)``
    :param frame_start: first frame to segment
    :param background_mean: per-pixel mean, on the same scale as the values
    :param background_std: per-pixel standard deviation, same scale
    :param alpha: threshold multiplier applied to the standard deviation
    :param pixel_value: ``PixelValue.GRAY`` or ``PixelValue.HSV``
    :param total_frames: only used as the progress-bar length
    :param disable_tqdm: hide the progress bar
    :return: iterator of ``(frame, mask)``, with ``mask`` uint8 in {0, 255}
    :raises ValueError: for an unsupported ``pixel_value``
    """
    for im in tqdm(video.get_frames(frame_start),
                   total=total_frames,
                   file=sys.stdout,
                   desc="Non-adaptive gaussian model...",
                   disable=disable_tqdm):

        # Compare against the enum member; a bare `elif PixelValue.HSV:` is
        # always truthy, which made the error branch unreachable.
        if pixel_value == PixelValue.GRAY:
            im_values = np.mean(im, axis=-1) / 255
        elif pixel_value == PixelValue.HSV:
            im_values = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)[:, :, 0] / 180
        else:
            raise ValueError('Unsupported pixel value: {}'.format(pixel_value))

        # Use the absolute deviation from the mean, so pixels both darker
        # and brighter than the background are detected.
        mask = np.abs(im_values - background_mean) >= (
            alpha * (background_std + 5 / 255))

        yield im, mask.astype(np.uint8) * 255
예제 #9
0
def get_video(vid):
    """Return the video with id ``vid`` as JSON, or a 404 ``json_error``."""
    found = Video.get_or_none(id=vid)
    if found is not None:
        return jsonify(found.to_dict())
    return json_error("Video not found", 404)
예제 #10
0
def gaussian_model_adaptive(video: Video,
                            train_stop_frame: int,
                            background_mean: np.ndarray,
                            background_std: np.ndarray,
                            alpha: float = 2.5,
                            rho: float = 0.1,
                            pixel_value: PixelValue = PixelValue.GRAY,
                            total_frames: int = None,
                            disable_tqdm=False) -> Iterator[tuple]:
    """Segment the foreground with a per-pixel Gaussian model that adapts over time.

    Frames from ``train_stop_frame`` on are reduced to one channel: the
    grayscale mean divided by 255, or the OpenCV hue divided by 180. A pixel
    is foreground when
    ``|value - mean| >= alpha * (std + 5/255)``. After each frame, the mean
    and std are blended toward the new frame with learning rate ``rho``.

    :param video: frame source, read via ``video.get_frames(train_stop_frame)``
    :param train_stop_frame: first frame to segment
    :param background_mean: initial per-pixel mean (same scale as the values)
    :param background_std: initial per-pixel standard deviation
    :param alpha: threshold multiplier applied to the standard deviation
    :param rho: weight given to the current frame in the running update
    :param pixel_value: ``PixelValue.GRAY`` or ``PixelValue.HSV``
    :param total_frames: only used as the progress-bar length
    :param disable_tqdm: hide the progress bar
    :return: iterator of ``(frame, mask)``, with ``mask`` uint8 in {0, 255}
    :raises ValueError: for an unsupported ``pixel_value``
    """
    for im in tqdm(video.get_frames(train_stop_frame),
                   total=total_frames,
                   file=sys.stdout,
                   desc='Adaptive gaussian model...',
                   disable=disable_tqdm):

        # Compare against the enum member; a bare `elif PixelValue.HSV:` is
        # always truthy, which made the error branch unreachable.
        if pixel_value == PixelValue.GRAY:
            im_values = np.mean(im, axis=-1) / 255
        elif pixel_value == PixelValue.HSV:
            im_values = cv2.cvtColor(im, cv2.COLOR_BGR2HSV)[:, :, 0] / 180
        else:
            raise ValueError('Unsupported pixel value: {}'.format(pixel_value))

        # Use the absolute deviation, so both darker and brighter pixels count.
        mask = np.abs(im_values - background_mean) >= (
            alpha * (background_std + 5 / 255))
        # Running update of the model. The std update uses the new mean.
        # NOTE(review): every pixel is updated, including foreground pixels;
        # adaptive models often update only background pixels (mask == 0)
        # — confirm which is intended.
        background_mean = rho * im_values + (1 - rho) * background_mean
        background_std = np.sqrt(rho *
                                 np.power((im_values - background_mean), 2) +
                                 (1 - rho) * np.power(background_std, 2))

        yield im, mask.astype(np.uint8) * 255
예제 #11
0
def get_background_model(video: Video,
                         train_stop_frame: int,
                         total_frames: int = None,
                         pixel_value: PixelValue = PixelValue.GRAY,
                         disable_tqdm=False) -> (np.ndarray, np.ndarray):
    """Estimate a per-pixel background mean and standard deviation.

    Frames read from ``video.get_frames(0, train_stop_frame)`` are reduced to
    one channel (grayscale mean, or OpenCV hue) and stacked. Statistics are
    taken over the frame axis and scaled by 1/255 (gray) or 1/180 (hue).

    :param video: frame source
    :param train_stop_frame: number of frames to train on (stack depth)
    :param total_frames: only used as the progress-bar length
    :param pixel_value: ``PixelValue.GRAY`` or ``PixelValue.HSV``
    :param disable_tqdm: hide the progress bar
    :return: ``(background_mean, background_std)``, each of shape (H, W)
    :raises ValueError: for an unsupported ``pixel_value`` or when no frame
        was read
    """
    # Compare against the enum member; a bare `elif PixelValue.HSV:` is
    # always truthy, which made the error branch unreachable.
    if pixel_value == PixelValue.GRAY:
        scale = 255
    elif pixel_value == PixelValue.HSV:
        scale = 180
    else:
        raise ValueError('Unsupported pixel value: {}'.format(pixel_value))

    background_list = None
    n_frames = 0
    for i, im in enumerate(tqdm(video.get_frames(0, train_stop_frame),
                                total=total_frames,
                                file=sys.stdout,
                                desc='Training model...',
                                disable=disable_tqdm)):
        if background_list is None:
            background_list = np.zeros(
                (im.shape[0], im.shape[1], train_stop_frame), dtype=np.int16)

        if pixel_value == PixelValue.GRAY:
            background_list[:, :, i] = np.mean(im, axis=-1)
        else:
            background_list[:, :, i] = cv2.cvtColor(im,
                                                    cv2.COLOR_BGR2HSV)[:, :, 0]
        n_frames = i + 1

    if background_list is None:
        raise ValueError('No frames were read to build the background model')

    # Only use the slots that were filled, so a video shorter than
    # train_stop_frame does not bias the statistics toward the zero padding.
    filled = background_list[:, :, :n_frames]
    background_mean = np.mean(filled, axis=-1) / scale
    background_std = np.std(filled, axis=-1) / scale

    return background_mean, background_std
예제 #12
0
    def import_from_lines(self, lines):
        """
        Import records from tab-separated lines into the database.

        Each line needs at least 7 tab-separated fields. ``Video`` is built
        from a leading ``0`` followed by the first seven fields. Malformed
        lines and failed inserts are logged and skipped.

        :param lines: lines to process
        :return: (number of lines read, number of records inserted successfully)
        """
        read_count = 0
        ok_count = 0

        for line in lines:
            read_count += 1
            fields = line.strip().split('\t')
            # If the input is not UTF-8 encoded, convert it first, e.g.
            # line.encode('<the_coding_of_str>').decode('utf-8').strip().split('\t')
            if len(fields) < 7:
                CrawlerLogger.logger.info(
                    'line format error: {0}'.format(line))
                continue

            candidate = Video(0, *fields[:7])
            if self.insert_one_video(candidate):
                ok_count += 1
            else:
                CrawlerLogger.logger.info(
                    'insert line failed: {0}'.format(line))

        return read_count, ok_count
예제 #13
0
def put_video_in_queue():
    """Add a video to the waiting queue together with its target qualities.

    Expects a JSON body with ``video_id`` and ``qualities``. Nothing is
    queued if the video does not exist or is already queued. If saving the
    qualities fails, the new queue entry is removed and an error is returned.

    :return: JSON of the video on success, otherwise a ``json_error`` response.
    """
    # get_json(silent=True) returns None for a missing or non-JSON body
    # instead of raising; treat that (or a non-object body) as empty.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}

    if 'video_id' not in payload:
        return json_error("Video not found")

    if 'qualities' not in payload:
        return json_error("Target qualities not found")

    qualities = payload['qualities']
    video_id = payload['video_id']

    video = Video.get_or_none(id=video_id)
    if video is None:
        return json_error("Video not found")

    if WaitingQueue.get_or_none(WaitingQueue.video == video) is not None:
        return json_error("Already in queue")

    video_queue = WaitingQueue(video=video)
    video_queue.save()

    if not save_video_qualities(video, qualities):
        # Roll back the queue entry and report the failure, instead of
        # returning the video as if it had been queued.
        video_queue.delete_instance()
        return json_error("Failed to save target qualities")

    return jsonify(video.to_dict())
예제 #14
0
def add(user_id, video_id):
    """Show the add-video form (GET) or create a follow-up video (POST).

    On POST:
    - the form's ``video`` value is stored on the existing video
      ``video_id`` as ``other_video``;
    - a new unpublished ``Video`` owned by ``user_id`` is created;
    - the session is committed and the user is redirected to the homepage.
    """
    if request.method == 'GET':
        return render_template('add.html', user_id=user_id, video_id=video_id)

    # The user and the previous video are only needed when posting, so they
    # are no longer queried just to render the form.
    user = session.query(User).filter_by(id=user_id).first()
    curr_video = session.query(Video).filter_by(id=video_id).first()

    video = request.form.get('video')
    description = request.form.get('description')

    owner = user.id

    # Link the previous video to the new one's value.
    curr_video.other_video = video

    new_vid = Video(video=video, description=description, publish=False, owner=owner)
    session.add(new_vid)
    session.commit()

    return redirect(url_for('homepage', user_id=user_id))
예제 #15
0
def load_videos():
    """Load videos from seed data into database.

    Each line of ``seed_data/videos.txt`` is ``|``-separated:
    ``video_id|user_id|video|category|youtube|hidden``. The last two fields
    are booleans; only the exact text ``True`` counts as true. Empty text
    fields are left out of the constructor call (presumably so the model's
    column defaults apply — confirm). All rows are committed at the end.
    """
    with open("seed_data/videos.txt") as videos:
        for row in videos:
            fields = row.rstrip().split("|")

            kwargs = dict(
                video_id=fields[0],
                user_id=fields[1],
                video=fields[2],
                category=fields[3],
                youtube=fields[4] == "True",
                hidden=fields[5] == "True",
            )
            # Drop empty-string fields; the boolean flags are never "".
            kwargs = {key: value for key, value in kwargs.items() if value != ""}

            db.session.add(Video(**kwargs))

    db.session.commit()
예제 #16
0
 def openFile(self, filename):
     """Open ``filename`` and load it according to its extension.

     - Video extensions (``self.videofiletypes``) are wrapped in ``Video``
       and start the video player.
     - Metadata files (``self.metadatatypes``), ``.pkl`` and ``.h5`` files
       are handed to ``self.startWorker``.
     - Any other extension shows the 'FileType' warning.
     """
     extension = filename.split('.')[-1]
     self.generateButton.setText("Loading...")
     self.infoLabel.setText("File opened: " + filename)
     print("Opening " + filename)

     # Video files start the player directly; the rest are loaded by a worker.
     if extension in self.videofiletypes:
         self.currentFile = Video(filepath=filename, colortype="rgb")
         self.generateButton.setText("Generate images from video")
         self.generateButton.setDisabled(False)
         self.startVideoPlayer()
         return

     if extension in self.metadatatypes:
         self.metafilename = filename
         self.startWorker(ip.loadCsv, self.setCurrentFrames, self.fillGallery,
                          self.metafilename, self.currentFrames)
     elif extension == 'pkl':
         self.currentFile = filename
         self.startWorker(ip.loadPickle, self.setCurrentFrames, self.fillGallery,
                          self.currentFile)
     elif extension == 'h5':
         self.currentFile = filename
         self.startWorker(ip.loadhdf5, self.setCurrentFrames, self.fillGallery,
                          self.currentFile, self.currentFrames)
     else:
         # Unsupported file type selected.
         self.showWarning('FileType')
def main():
    """Sweep rho for the adaptive model with alpha fixed at 1.75 and plot mAP.

    First builds the background model once on ``int(2141 * 0.25)`` frames
    ("Ensure cache"). It then evaluates ``w2_map_alpha`` for 20 log-spaced
    rho values in 3 parallel jobs and plots mAP against rho.
    """
    alpha_values = np.linspace(1.5, 3, 20)  # used by the alpha sweep below
    rho_values = np.logspace(-2, -0.1, 20)

    # Ensure cache
    video = Video("../datasets/AICity_data/train/S03/c010/frames")
    get_background_model(video,
                         int(2141 * 0.25),
                         total_frames=int(2141 * 0.25),
                         disable_tqdm=False)

    # Best alpha: 1.75
    # Alpha sweep:
    # mAP_list = Parallel(n_jobs=4)(delayed(w2_map_alpha)(alpha) for alpha in tqdm(alpha_values))

    # NOTE(review): called with (alpha, rho), but a single-argument
    # w2_map_alpha(alpha) also exists in this code — confirm which version
    # is in scope here.
    mAP_list = Parallel(n_jobs=3)(delayed(w2_map_alpha)(1.75, rho)
                                  for rho in tqdm(rho_values))
    # Previously recorded results of this sweep:
    # mAP_list = [0.18656629994209614, 0.23257430508572247, 0.2333781161367368, 0.18301435406698566,
    #             0.1773032336790726, 0.1762025561112319, 0.12792207792207794, 0.17066218427456575,
    #             0.12438077386530996, 0.12091293755609694, 0.11872632575757576, 0.1189064558629776,
    #             0.15132634758802985, 0.157589106928314, 0.26284443191338397, 0.39380709780347006,
    #             0.43192630414348515, 0.357941584643725, 0.3186317361126976, 0.20596422790608496]

    plt.figure()
    plt.plot(rho_values, mAP_list)
    plt.xlabel(r'$\rho$ threshold')
    plt.ylabel('mAP')
    plt.show()
예제 #18
0
파일: crud.py 프로젝트: monkeysaa/ed-vid
def create_video(link, title, notes=""):
    """Create a ``Video`` from the given fields, commit it, and return it."""
    new_video = Video(link=link, title=title, notes=notes)
    session = db.session
    session.add(new_video)
    session.commit()
    return new_video
예제 #19
0
	def get(self):
		"""Render ``video_list.html`` with at most 10 videos from ``Video.all()``."""
		videos = Video.all().run(limit=10)
		context = {'list_videos': videos}

		page = jinja_environment.get_template('/templates/video_list.html')
		self.response.out.write(page.render(context))
예제 #20
0
def thead_list(play_list_id):
    """Download every pending video of a playlist, claiming one video at a time.

    Each iteration:
    - picks the first video of playlist ``play_list_id`` that is neither in
      progress nor completed;
    - claims it with a conditional update of ``is_progress`` from 0 to 1
      (the claim only succeeds if the row was still 0);
    - downloads it and marks it completed.

    Returns when no pending video is left.
    """
    play = PlayList.get_by_id(play_list_id)
    while True:
        first_video = (play.videos
                       .where(Video.is_progress == 0)
                       .where(Video.is_completed == 0)
                       .first())
        if first_video is None:
            break

        claimed = (Video.update(is_progress=1)
                   .where(Video.is_progress == 0)
                   .where(Video.id == first_video.id)
                   .execute())
        if claimed != 1:
            # Someone else claimed it first; look for the next one.
            continue

        try:
            Bilibili().download_by_id(first_video.id)
        except BaseException:
            # Release the claim so a failed download is not left marked as
            # "in progress" forever (and never retried), then re-raise.
            Video.update(is_progress=0).where(
                Video.id == first_video.id).execute()
            raise
        Video.update(is_completed=True).where(
            Video.id == first_video.id).execute()
예제 #21
0
    def download_by_id(self, video_id):
        """Download video ``video_id`` into a folder named after its playlist.

        Looks up the video and its playlist, then prints a start message.
        It creates ``<Bilibili.download_path>/<playlist title>`` if needed
        and delegates the download to ``self.get_url``.
        """
        video = Video.select().where(Video.id == video_id).first()
        play = PlayList.select().where(
            PlayList.id == video.play_list_id).first()

        # Prints "Start downloading <title>, playlist: <playlist title>".
        print("开始下载" + video.title + ",视频集:" + play.title)
        download_path = Bilibili.download_path + '/' + play.title
        # makedirs(exist_ok=True) avoids the exists-then-mkdir race. It also
        # creates a missing base download directory, where os.mkdir failed.
        os.makedirs(download_path, exist_ok=True)
        self.get_url(video.cid, video.title, download_path)
예제 #22
0
def create_video(video_path, description, date_posted):
    """Create new video (info).

    Builds a ``Video`` from the path, description and posting date, commits
    it through ``db.session``, and returns the instance.
    """
    fields = dict(video_path=video_path,
                  description=description,
                  date_posted=date_posted)
    new_video = Video(**fields)

    db.session.add(new_video)
    db.session.commit()
    return new_video
예제 #23
0
def videolist_since():
    """Return the videos selected by ``Video.select_since`` as a JSON array.

    The optional query parameters ``since_id`` and ``count`` are parsed as
    integers and forwarded to ``Video.select_since``. The resulting dicts pass
    through ``FilterPluginExecutor`` before serialisation. Responds with 400
    when a parameter is not a valid integer.
    """
    select_params = {}
    for name in ("since_id", "count"):
        raw = flask.request.args.get(name)
        if raw is None:
            continue
        try:
            select_params[name] = int(raw)
        except ValueError:
            # Invalid client input is a 400, not an unhandled 500.
            flask.abort(400)

    videos = Video.select_since(**select_params)
    videos_dict = [video.to_dict() for video in videos]
    videos_dict = FilterPluginExecutor.get_instance()(videos_dict)
    return flask.Response(response=json.dumps(videos_dict),
                          mimetype="application/json")
예제 #24
0
def upload_video():
    """Handle file uploads.

    GET: render ``upload-video.html`` with the ``case_id`` query parameter.

    POST:
    - read the ``media`` file and the optional ``tscript`` transcript;
    - record a ``Video`` row, plus a ``Transcript`` row when a transcript
      is sent;
    - start a background thread running ``upload_aws_db`` for the file;
    - return ``case_id`` as JSON.
    """
    if request.method == "GET":
        # Load the form for the user to upload a new video.
        case_id = request.args.get('case_id')
        return render_template('upload-video.html', case_id=case_id)

    if request.method == "POST":
        video_file = request.files['media']
        case_id = request.form.get('case_id')
        video_name = video_file.filename
        # files.get returns None when no transcript was sent; it does not
        # raise, so no try/except is needed.
        transcript_file = request.files.get("tscript")
        user_id = g.current_user.user_id

        # Deponent name and recording date from the form.
        deponent = request.form.get('name')
        recorded_at = request.form.get('date-taken')

        # Add the video to the db.
        new_vid = Video(case_id=case_id,
                        vid_name=video_name,
                        added_by=user_id,
                        added_at=datetime.now(),
                        deponent=deponent,
                        recorded_at=recorded_at)
        db.session.add(new_vid)
        db.session.commit()

        if transcript_file:
            # NOTE(review): readlines() stores a list of lines; confirm that
            # Transcript.text accepts it (read() may be intended).
            script_text = transcript_file.readlines()
            new_script = Transcript(vid_id=new_vid.vid_id, text=script_text)
            db.session.add(new_script)
            db.session.commit()

        # Upload to AWS on a separate thread so the user can move on.
        # NOTE(review): the request's file stream may be closed once the
        # request ends; consider saving it before handing it to the thread.
        threading.Thread(target=upload_aws_db,
                         args=(video_file, video_name, case_id,
                               user_id, socketio)).start()

        return jsonify(case_id)
예제 #25
0
파일: parsers.py 프로젝트: s-perfilev/bcsb
def update_episode_urls(episode):
    """Create and attach ``Video`` nodes for the episode's URLs.

    :param episode: episode whose URLs are fetched with ``get_episode_urls``
    :return: the URLs that were newly saved and connected. URLs that already
        exist (``neomodel.UniqueProperty`` raised on save) are skipped.
    """
    new_urls = []

    for url in get_episode_urls(episode):
        try:
            video = Video(link=url).save()
        except neomodel.UniqueProperty:
            # Already stored: nothing new to connect.
            continue
        episode.videos.connect(video)
        # Append this URL itself, not the full list of URLs.
        new_urls.append(url)

    return new_urls
예제 #26
0
def create_video(channel_name, web_title, youtube_title):
    """Create, commit and return a ``Video`` with a channel name and two titles."""
    record = Video(channel_name=channel_name,
                   web_title=web_title,
                   youtube_title=youtube_title)
    db.session.add(record)
    db.session.commit()
    return record
예제 #27
0
 def get_url(self, cid, filename, path):
     """Fetch the download link(s) for ``cid`` and download the file(s).

     A single ``durl`` entry means a single-file video. Its size is stored on
     the matching ``Video`` rows, and it is saved as ``<path>/<filename>.mp4``.
     With several entries, each part is downloaded to its own numbered
     ``<path>/<filename>.<index>.tmp`` file. NOTE(review): the parts are not
     merged here.

     :return: ``video_list``, which is never filled (always empty).
     """
     html = self.get_response_by_cid(cid=cid)
     print(html)
     video_list = []
     if len(html['durl']) == 1:
         # A single link means a single-file video.
         print(html['durl'][0])
         Video.update(size=html['durl'][0]['size']).where(
             Video.cid == cid).execute()
         self.download(html['durl'][0]['url'],
                       path + '/' + filename + '.mp4', self.next_headers)
     else:
         # Otherwise the video is split into a list of parts.
         temps = []
         for index, part in enumerate(html['durl']):
             print(part)
             # Each part gets its own file; a shared name made every part
             # overwrite the previous one.
             temp = path + '/' + filename + '.' + str(index) + '.tmp'
             temps.append(temp)
             self.download(part['url'], temp, self.next_headers)
     return video_list
예제 #28
0
def w2_map_alpha(alpha):
    """Run ``week2_nonadaptive`` with ``alpha`` on S03/c010 and return the mAP.

    Also prints alpha, the mAP, and the mean per-frame detection IoU.
    """
    video = Video("../datasets/AICity_data/train/S03/c010/frames")
    frames, ious = [], []
    # Consume the generator lazily so images and masks are not all kept in memory.
    for _, _, frame in week2_nonadaptive(video, alpha, disable_tqdm=True):
        frames.append(frame)
        ious.append(frame.get_detection_iou_mean(ignore_classes=True))

    score = mean_average_precision(frames)

    print('alpha', alpha, 'mAP', score, 'mean IoU', np.mean(ious))

    return score
예제 #29
0
def add_video():
    """add video to database

    If the form carries a ``youtube`` value, that value is stored as a
    YouTube video. Otherwise the uploaded file is saved with
    ``functions.save_photo("video")`` and its result is stored. Debug
    information is printed along the way. Finally, the user is redirected
    to their video list.
    """
    user_id = session["current_user"]
    category = request.form.get("category")
    youtube = request.form.get("youtube")

    print(category)
    print(youtube)

    is_youtube = bool(youtube)
    if is_youtube:
        video_value = youtube
    else:
        video_value = functions.save_photo("video")
        print(video_value)

    db.session.add(Video(user_id=user_id,
                         video=video_value,
                         category=category,
                         youtube=is_youtube))
    db.session.commit()
    print("added youtube" if is_youtube else "added vid")

    user = User.query.get(user_id)
    for vid in user.videos:
        print(vid.video, vid.category, vid.youtube)

    return redirect("/users/{}/my_videos".format(user_id))
예제 #30
0
def create_video(video_title, video_duration, video_url, playlist_id):
    """Create and return a new video to a playlist.

    The ``Video`` references its playlist through ``playlist_id`` and is
    committed before being returned.
    """
    attrs = {
        "video_title": video_title,
        "video_duration": video_duration,
        "video_url": video_url,
        "playlist_id": playlist_id,
    }
    new_video = Video(**attrs)

    db.session.add(new_video)
    db.session.commit()
    return new_video
예제 #31
0
	def post(self):
		"""Validate the video form and store a new ``Video`` when it is valid.

		Reads ``txtTitle``, ``txtLink`` and ``txtPriority`` from the request.
		The link must contain a ``v=`` parameter; its value, up to the next
		``&``, is stored as ``vd_link``. The priority must be an integer from
		1 to 5. ``vc.success_flag`` tells the template whether the video was
		saved.
		"""
		title = TString.trim(self.request.get('txtTitle'))
		link = TString.trim(self.request.get('txtLink'))
		priority = TString.trim(self.request.get('txtPriority'))

		# Extract the id from the "v=" parameter. A raw string avoids
		# invalid-escape warnings (the pattern itself is unchanged). A link
		# that does not match, or has no "v=", is rejected instead of
		# raising AttributeError/IndexError.
		video_id = None
		m = re.search(r'(http:\/\/)?\w\w\w\.[a-z]+\.[a-z]*.*', link)
		if m is not None and "v=" in m.group(0):
			video_id = m.group(0).split("v=")[1].split("&")[0]

		vc = RegisterContext()
		template = jinja_environment.get_template('/templates/video_form.html')

		vc.success_flag = (bool(video_id) and priority.isdigit()
			and 1 <= int(priority) <= 5)

		if vc.success_flag:
			new_v = Video(vd_title=title, vd_link=video_id, vd_priority=int(priority))
			new_v.put()

		self.response.out.write(template.render({'video_ctx': vc}))
예제 #32
0
    def test_get_episode(self):
        '''Verify that Video.episode yields the episode the video is attached to.

        Only the first episode of season 1 gets the video, so the expected
        episode number is 1.
        '''
        configure_for_unittest()
        from model import Video, Show, Season, Episode

        show = Show(title='House of Cards').save()
        seasons = [Season(show=show, number=n).save() for n in (1, 2)]
        ep_1_1 = Episode(season=seasons[0], number=1,
                         release_date=dt(2010, 1, 1)).save()
        ep_1_2 = Episode(season=seasons[0], number=2).save()
        ep_2_1 = Episode(season=seasons[1], number=1).save()
        video = Video(link='vk.com').save()

        for season in seasons:
            show.seasons.connect(season)
        seasons[0].episodes.connect(ep_1_1)
        seasons[0].episodes.connect(ep_1_2)
        seasons[1].episodes.connect(ep_2_1)
        ep_1_1.videos.connect(video)

        video.refresh()
        self.assertEqual(video.episode.get().number, 1)
예제 #33
0
def play(play_id):
    """Render ``play.html`` for a playlist together with size info per video.

    - Videos whose stored ``size`` is 0 get their size from
      ``get_response_by_cid``, when the response has exactly one ``durl``
      entry. The size is saved to the database and copied into the rendered
      data.
    - Each video dict gets ``file_size``: the size in bytes of the
      downloaded ``.mp4`` if it exists, otherwise 0.
    """
    play_list = PlayList.select().where(PlayList.id == play_id).first()

    videos = Video.select().where(Video.play_list_id == play_id)
    video_dict = models_to_dict(videos)
    client = Bilibili()
    for video in video_dict:
        file = (Bilibili.download_path + '/' + play_list.title + "/" +
                video['title'] + '.mp4')

        if video['size'] == 0:
            html = client.get_response_by_cid(video['cid'])
            if 'durl' in html and len(html['durl']) == 1:
                # A single link means a single-file video.
                print(html['durl'][0])
                size = html['durl'][0]['size']
                Video.update(size=size).where(
                    Video.cid == video['cid']).execute()
                # Keep the rendered data in sync with the row just updated.
                video['size'] = size
            print(html)

        video['file_size'] = os.path.getsize(file) if os.path.exists(file) else 0
    return render_template("play.html", play=play_list, videos=video_dict)
예제 #34
0
def fine_tune_yolo(debug=False):
    """Fine-tune pretrained YOLOv3 on the AI City S03/c010 annotations.

    Every module before the first ``yolo`` layer is switched to eval mode and
    its parameters are frozen, so only the remaining layers are optimised:
    Adam with lr=1e-5, batch size 16, 10 epochs. The weights are then saved
    to ``../weights/fine_tuned_yolo_freeze.weights``.

    :param debug: currently unused; kept for interface compatibility.
    """
    video = Video("../datasets/AICity_data/train/S03/c010/frames")
    detection_transform = DetectionTransform()
    classes = utils.load_classes('../config/coco.names')

    model = Darknet('../config/yolov3.cfg')
    print(model)
    model.load_weights('../weights/yolov3.weights')
    model.train()
    for module_def, module in zip(model.module_defs, model.module_list):
        if module_def["type"] == "yolo":
            break
        module.train(False)
        # train(False) only changes layer behaviour (e.g. batch norm). The
        # weights must also stop requiring gradients, otherwise the
        # requires_grad filter below would still hand them to the optimizer.
        for param in module.parameters():
            param.requires_grad = False

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()),
                     lr=1e-5)
    gt = read_annotations(
        '../datasets/AICity_data/train/S03/c010/m6-full_annotation.xml')
    dataset = YoloDataset(video, gt, classes, transforms=detection_transform)
    data_loader = DataLoader(dataset,
                             batch_size=16,
                             shuffle=True,
                             num_workers=4)

    for epoch in tqdm(range(10), file=sys.stdout, desc='Fine tuning'):
        for images, targets in tqdm(data_loader,
                                    file=sys.stdout,
                                    desc='Running epoch'):
            if use_cuda:
                images = images.cuda()
                targets = targets.cuda()

            optimizer.zero_grad()
            # Darknet returns the training loss when targets are given.
            loss = model(images, targets)
            loss.backward()
            optimizer.step()

    print('Training finished. Saving weights...')
    model.save_weights('../weights/fine_tuned_yolo_freeze.weights')
    print('Saved weights')