Example #1
def test_input_fn(model):
    # Test-time input: one clip per video, with random mirroring and cropping
    # disabled so that evaluation is deterministic.
    model_helper.AddVideoInput(
        model,
        test_reader,
        batch_size=batch_per_device,
        length_rgb=args.clip_length_rgb,
        clip_per_video=1,
        decode_type=0,
        random_mirror=False,
        random_crop=False,
        sampling_rate_rgb=args.sampling_rate_rgb,
        scale_h=args.scale_h,
        scale_w=args.scale_w,
        crop_size=args.crop_size,
        video_res_type=args.video_res_type,
        short_edge=min(args.scale_h, args.scale_w),
        num_decode_threads=args.num_decode_threads,
        do_multi_label=args.multi_label,
        num_of_class=args.num_labels,
        input_type=args.input_type,
        length_of=args.clip_length_of,
        sampling_rate_of=args.sampling_rate_of,
        frame_gap_of=args.frame_gap_of,
        do_flow_aggregation=args.do_flow_aggregation,
        flow_data_type=args.flow_data_type,
        get_rgb=(args.input_type == 0),
        get_optical_flow=(args.input_type == 1),
        get_video_id=args.get_video_id,
        use_local_file=args.use_local_file,
    )
Example #2
def add_video_input(model):
    # Training input: random cropping and mirroring are enabled for data
    # augmentation, and the batch size is per device. Depending on
    # args.input_type, RGB frames, optical flow, and log-mel features are read.
    model_helper.AddVideoInput(
        model,
        train_reader,
        batch_size=batch_per_device,
        length_rgb=args.clip_length_rgb,
        clip_per_video=1,
        random_mirror=True,
        decode_type=0,
        sampling_rate_rgb=args.sampling_rate_rgb,
        scale_h=args.scale_h,
        scale_w=args.scale_w,
        crop_size=args.crop_size,
        video_res_type=args.video_res_type,
        short_edge=min(args.scale_h, args.scale_w),
        num_decode_threads=args.num_decode_threads,
        do_multi_label=args.multi_label,
        num_of_class=args.num_labels,
        random_crop=True,
        input_type=args.input_type,
        length_of=args.clip_length_of,
        sampling_rate_of=args.sampling_rate_of,
        frame_gap_of=args.frame_gap_of,
        do_flow_aggregation=args.do_flow_aggregation,
        flow_data_type=args.flow_data_type,
        get_rgb=(args.input_type == 0 or args.input_type >= 3),
        get_optical_flow=(args.input_type == 1 or args.input_type >= 4),
        get_logmels=(args.input_type >= 2),
        get_video_id=args.get_video_id,
        jitter_scales=[int(n) for n in args.jitter_scales.split(',')],
        use_local_file=args.use_local_file,
    )
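In the training scripts these snippets come from, a builder like `add_video_input` is not called directly; it is handed to Caffe2's `data_parallel_model` so that each GPU gets its own input pipeline. The sketch below is a minimal, hypothetical wiring: `train_model`, `train_reader`, `create_model_ops`, `add_parameter_update_ops`, and `args.num_gpus` are assumptions standing in for pieces defined elsewhere in such a script, not part of the examples above.

```python
from caffe2.python import data_parallel_model, workspace

# Hypothetical wiring: add_video_input (defined above) serves as the
# input_builder_fun, so an AddVideoInput pipeline is created once per device.
data_parallel_model.Parallelize_GPU(
    train_model,                                         # assumed training ModelHelper
    input_builder_fun=add_video_input,
    forward_pass_builder_fun=create_model_ops,           # assumed: builds the video network
    param_update_builder_fun=add_parameter_update_ops,   # assumed: adds the parameter updates
    devices=list(range(args.num_gpus)),
)

workspace.RunNetOnce(train_model.param_init_net)
workspace.CreateNet(train_model.net)
```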
Example #3
def test_input_fn(model):
    # Multi-clip evaluation: decode_type=1 with clip_per_video draws several
    # clips from each video. Note that this closure uses test_model and
    # test_reader from the enclosing scope rather than the `model` argument.
    model_helper.AddVideoInput(
        test_model,
        test_reader,
        batch_size=args.batch_size,
        clip_per_video=args.clip_per_video,
        decode_type=1,
        length_rgb=args.clip_length_rgb,
        sampling_rate_rgb=args.sampling_rate_rgb,
        scale_h=args.scale_h,
        scale_w=args.scale_w,
        crop_size=args.crop_size,
        num_decode_threads=4,
        num_of_class=args.num_labels,
        random_mirror=False,
        random_crop=False,
        input_type=args.input_type,
        length_of=args.clip_length_of,
        sampling_rate_of=args.sampling_rate_of,
        frame_gap_of=args.frame_gap_of,
        do_flow_aggregation=args.do_flow_aggregation,
        flow_data_type=args.flow_data_type,
        get_rgb=(args.input_type == 0),
        get_optical_flow=(args.input_type == 1),
        get_video_id=args.get_video_id,
        use_local_file=args.use_local_file,
    )
Example #4
def test_input_fn(model):
    # Forwards shared keyword arguments via **video_input_args; uses
    # test_data_loader and test_reader from the enclosing scope.
    model_helper.AddVideoInput(test_data_loader, test_reader,
                               **video_input_args)
Example #5
def input_fn(model):
    # Same pattern, using the passed-in model and a shared argument dict.
    model_helper.AddVideoInput(model, reader, **video_input_args)
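Examples #4 and #5 forward a shared keyword dictionary instead of spelling every argument out. The dictionary itself is not shown in these snippets; the sketch below is one plausible, hypothetical shape for `video_input_args`, built only from keyword arguments that already appear in examples #1 through #3.

```python
# Hypothetical video_input_args: the keys mirror the keyword arguments passed
# explicitly in the earlier examples, so every call site stays consistent.
video_input_args = dict(
    batch_size=args.batch_size,
    clip_per_video=args.clip_per_video,
    decode_type=1,                      # evaluation-style decoding, as in example #3
    length_rgb=args.clip_length_rgb,
    sampling_rate_rgb=args.sampling_rate_rgb,
    scale_h=args.scale_h,
    scale_w=args.scale_w,
    crop_size=args.crop_size,
    num_of_class=args.num_labels,
    random_mirror=False,
    random_crop=False,
    input_type=args.input_type,
    get_rgb=(args.input_type == 0),
    get_optical_flow=(args.input_type == 1),
    get_video_id=args.get_video_id,
    use_local_file=args.use_local_file,
)
```

The dict is then unpacked at each call site with `**video_input_args`, exactly as in examples #4 and #5.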