import importlib
import os

import cv2 as cv
import torch

import ltr.admin.settings as ws_settings


def run_training(train_module, train_name):
    """Run a train script in train_settings.
    args:
        train_module: Name of module in the "train_settings/" folder.
        train_name: Name of the train settings file.
    """
    # Set OpenCV to single-threaded mode.
    cv.setNumThreads(0)

    print('Training: {} {}'.format(train_module, train_name))

    settings = ws_settings.Settings()
    if settings.env.workspace_dir == '':
        raise Exception('Setup your workspace_dir in "ltr/admin/local.py".')

    settings.module_name = train_module
    settings.script_name = train_name
    settings.project_path = 'ltr/{}/{}'.format(train_module, train_name)

    expr_module = importlib.import_module('ltr.train_settings.{}.{}'.format(train_module, train_name))
    expr_func = getattr(expr_module, 'run')
    expr_func(settings)
# Variant that additionally configures cuDNN benchmarking.
def run_training(train_module, train_name, cudnn_benchmark=True):
    """Run a train script in train_settings.
    args:
        train_module: Name of module in the "train_settings/" folder.
        train_name: Name of the train settings file.
        cudnn_benchmark: Use cudnn benchmark or not (default is True).
    """
    # This is needed to avoid strange crashes related to opencv
    cv.setNumThreads(0)

    torch.backends.cudnn.benchmark = cudnn_benchmark

    print('Training: {} {}'.format(train_module, train_name))

    settings = ws_settings.Settings()
    settings.module_name = train_module
    settings.script_name = train_name
    settings.project_path = 'ltr/{}/{}'.format(train_module, train_name)

    expr_module = importlib.import_module('ltr.train_settings.{}.{}'.format(train_module, train_name))
    expr_func = getattr(expr_module, 'run')
    expr_func(settings)
def load_pretrained(module, name, checkpoint=None, **kwargs):
    """Load a network trained using the LTR framework. This is useful when you want to initialize your new
    network with a previously trained model.
    args:
        module - Name of the train script module, i.e. the name of the folder in "ltr/train_settings/".
        name - The name of the train script.
        checkpoint - You can supply the checkpoint number or the full path to the checkpoint file (see load_network).
        **kwargs - These are passed to load_network (see that function).
    """
    settings = ws_settings.Settings()
    network_dir = os.path.join(settings.env.workspace_dir, 'checkpoints', 'ltr', module, name)
    return load_network(network_dir=network_dir, checkpoint=checkpoint, **kwargs)
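# Usage sketch (assumption, not part of the original file): initializing a new
# network from a previously trained DiMP-50 model. The module/name pair
# 'dimp'/'dimp50' and the checkpoint number are illustrative only.
#
#     net = load_pretrained('dimp', 'dimp50', checkpoint=50)
#
# Passing checkpoint=None should load the latest checkpoint found in the
# network directory (see load_network).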
# Variant (FCOT) that also copies command-line arguments into the settings.
def run_training(train_module, train_name, cudnn_benchmark=True, args=None):
    """Run a train script in train_settings.
    args:
        train_module: Name of module in the "train_settings/" folder.
        train_name: Name of the train settings file.
        cudnn_benchmark: Use cudnn benchmark or not (default is True).
        args: Parsed command-line arguments whose fields are copied into the settings object below.
    """
    # This is needed to avoid strange crashes related to opencv
    cv.setNumThreads(0)

    torch.backends.cudnn.benchmark = cudnn_benchmark

    print('Training: {} {}'.format(train_module, train_name))

    settings = ws_settings.Settings()
    settings.module_name = train_module
    settings.script_name = train_name
    settings.project_path = 'ltr/{}/{}'.format(train_module, train_name)

    # Copy the training options from the command line into the settings.
    settings.samples_per_epoch = args.samples_per_epoch
    settings.use_pretrained_dimp = args.use_pretrained_dimp
    settings.pretrained_dimp50 = args.pretrained_dimp50
    settings.load_model = args.load_model
    settings.fcot_model = args.fcot_model
    settings.train_cls_72_and_reg_init = args.train_cls_72_and_reg_init
    settings.train_reg_optimizer = args.train_reg_optimizer
    settings.train_cls_18 = args.train_cls_18
    settings.total_epochs = args.total_epochs
    settings.lasot_rate = args.lasot_rate
    settings.devices_id = args.devices_id
    settings.batch_size = args.batch_size
    settings.num_workers = args.num_workers
    settings.norm_scale_coef = args.norm_scale_coef
    if args.workspace_dir is not None:
        settings.env.workspace_dir = args.workspace_dir
        settings.env.tensorboard_dir = settings.env.workspace_dir + '/tensorboard/'

    expr_module = importlib.import_module('ltr.train_settings.{}.{}'.format(train_module, train_name))
    expr_func = getattr(expr_module, 'run')
    expr_func(settings)
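# Usage sketch (assumption, not part of the original file): the variant above is
# typically driven by a small run script. The argparse flags below simply mirror
# the attributes read from `args`; flag types and defaults are hypothetical.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Run a training session.')
    parser.add_argument('train_module', type=str, help='Name of module in the "train_settings/" folder.')
    parser.add_argument('train_name', type=str, help='Name of the train settings file.')
    # One flag per settings field copied in run_training; all default to None here.
    for flag in ['samples_per_epoch', 'use_pretrained_dimp', 'pretrained_dimp50', 'load_model',
                 'fcot_model', 'train_cls_72_and_reg_init', 'train_reg_optimizer', 'train_cls_18',
                 'total_epochs', 'lasot_rate', 'devices_id', 'batch_size', 'num_workers',
                 'norm_scale_coef', 'workspace_dir']:
        parser.add_argument('--' + flag, default=None)
    cli_args = parser.parse_args()

    run_training(cli_args.train_module, cli_args.train_name, args=cli_args)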
def __init__(self, root=None, sequences=None, version='2017', split='train', multiobj=True,
             vis_threshold=10, image_loader=opencv_loader):
    """
    args:
        root - Dataset root path. If unset, it uses the path in your local.py config.
        sequences - List of sequence names. Limit to a subset of sequences if not None.
        version - '2016' or '2017'.
        split - Any name in DAVIS/ImageSets/<year>.
        multiobj - Whether the dataset will return all objects in a sequence or multiple sequences
                   with one object in each.
        vis_threshold - Minimum number of pixels required to consider a target object "visible".
        image_loader - The function to read the images (default is opencv_loader).
    """
    if version == '2017':
        if split in ['train', 'val']:
            root = ws_settings.Settings().env.davis_dir if root is None else root
        elif split in ['test-dev']:
            root = ws_settings.Settings().env.davis_testdev_dir if root is None else root
        else:
            raise Exception('Unknown split {}'.format(split))
    else:
        root = ws_settings.Settings().env.davis16_dir if root is None else root

    super().__init__(name='DAVIS', root=Path(root), version=version, split=split, multiobj=multiobj,
                     vis_threshold=vis_threshold, image_loader=image_loader)

    dset_path = self.root
    self._jpeg_path = dset_path / 'JPEGImages' / '480p'
    self._anno_path = dset_path / 'Annotations' / '480p'

    meta_path = dset_path / "generated_meta.json"
    if meta_path.exists():
        self.gmeta = VOSMeta(filename=meta_path)
    else:
        self.gmeta = VOSMeta.generate('DAVIS', self._jpeg_path, self._anno_path)
        self.gmeta.save(meta_path)

    if sequences is None:
        if self.split != 'all':
            fname = dset_path / 'ImageSets' / self.version / (self.split + '.txt')
            sequences = open(fname).read().splitlines()
        else:
            # Use directory names, matching the string names read from the split files.
            sequences = [p.name for p in sorted(self._jpeg_path.glob("*")) if p.is_dir()]

    self.sequence_names = sequences
    self._samples = []

    for seq in sequences:
        obj_ids = self.gmeta.get_obj_ids(seq)
        if self.multiobj:  # Multiple objects per sample
            self._samples.append((seq, obj_ids))
        else:  # One object per sample
            self._samples.extend([(seq, [obj_id]) for obj_id in obj_ids])

    print("%s loaded." % self.get_name())
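# Usage sketch (assumption): constructing the 2017 validation split with one
# object per sample. The class name Davis is assumed from the file this
# __init__ belongs to; the root path comes from local.py unless `root` is given.
#
#     dataset = Davis(version='2017', split='val', multiobj=False)
#     seq_name, obj_ids = dataset._samples[0]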
import math
import torch
import torch.nn as nn
from collections import OrderedDict

import ltr.models.target_classifier.linear_filter as target_clf
import ltr.models.target_classifier.logistic_filter as occ_clf
# import ltr.models.target_classifier.linear_filter_rgbd as target_clf
import ltr.models.target_classifier.features as clf_features
import ltr.models.target_classifier.initializer as clf_initializer
import ltr.models.target_classifier.optimizer as clf_optimizer
import ltr.models.bbreg as bbmodels
import ltr.models.backbone as backbones
from ltr import model_constructor

import ltr.admin.settings as ws_settings

# Module-level configuration of the depth-aware components, read at model construction.
settings = ws_settings.Settings()
settings.depthaware_for_classiferonline = True
settings.depthaware_for_classifer_init = True
settings.depthaware_for_classifer_optimizer = False
settings.depthaware_for_iounet = True
settings.depthaware_alpha = 0.1


class DiMPnet_rgbd_blend1(nn.Module):
    """The DiMP network.
    args:
        feature_extractor: Backbone feature extractor network. Must return a dict of feature maps.
        classifier: Target classification module.
        bb_regressor: Bounding box regression module.
        classification_layer: Name of the backbone feature layer to use for classification.
        bb_regressor_layer: Names of the backbone layers to use for bounding box regression.
    """
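# Illustration only (assumption): the flags above suggest each depth-aware
# component blends RGB and depth features, e.g. via the convex combination
# below. `blend_features` is hypothetical and not part of this file; the
# weighting direction of depthaware_alpha is a guess.
def blend_features(rgb_feat, depth_feat, alpha=settings.depthaware_alpha):
    # alpha weights the depth branch; alpha=0 reduces to the RGB-only DiMP.
    return (1 - alpha) * rgb_feat + alpha * depth_feat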
def __init__(self, root=None, version='2019', split='train', cleanup=None, all_frames=False,
             sequences=None, multiobj=True, vis_threshold=10, image_loader=opencv_loader):
    """
    args:
        root - Dataset root path. If unset, it uses the path in your local.py config.
        version - '2018' or '2019'.
        split - 'test', 'train', 'valid', or 'jjtrain', 'jjvalid'. 'jjvalid' corresponds to a custom
                validation dataset consisting of 300 videos randomly sampled from the train set.
                'jjtrain' contains the remaining videos used for training.
        cleanup - List of actions to take to clean up known problems in the dataset.
                  'aspect': remove frames with weird aspect ratios,
                  'starts': fix up start frames from original meta data.
        all_frames - Whether to use an "all_frames" split.
        sequences - List of sequence names. Limit to a subset of sequences if not None.
        multiobj - Whether the dataset will return all objects in a sequence or multiple sequences
                   with one object in each.
        vis_threshold - Minimum number of pixels required to consider a target object "visible".
        image_loader - Image loader.
    """
    root = ws_settings.Settings().env.youtubevos_dir if root is None else root
    super().__init__(name="YouTubeVOS", root=Path(root), version=version, split=split, multiobj=multiobj,
                     vis_threshold=vis_threshold, image_loader=image_loader)

    split_folder = self.split
    if self.split.startswith("jj"):
        split_folder = "train"

    dset_path = self.root / self.version / split_folder
    # dset_path = self.root / split_folder

    self._anno_path = dset_path / 'Annotations'

    if all_frames:
        self._jpeg_path = self.root / self.version / (split_folder + "_all_frames") / 'JPEGImages'
        # self._jpeg_path = self.root / (split_folder + "_all_frames") / 'JPEGImages'
    else:
        self._jpeg_path = dset_path / 'JPEGImages'

    self.meta = YouTubeVOSMeta(dset_path)
    meta_path = dset_path / "generated_meta.json"
    if meta_path.exists():
        self.gmeta = VOSMeta(filename=meta_path)
    else:
        self.gmeta = VOSMeta.generate('YouTubeVOS', self._jpeg_path, self._anno_path)
        self.gmeta.save(meta_path)

    if all_frames:
        self.gmeta.enable_all_frames(self._jpeg_path)

    if self.split not in ['train', 'valid', 'test']:
        self.gmeta.select_split('youtubevos', self.split)

    if sequences is None:
        sequences = self.gmeta.get_sequence_names()

    to_remove = set()
    cleanup = {} if cleanup is None else set(cleanup)

    if 'aspect' in cleanup:
        # Remove sequences with unusual aspect ratios
        for seq_name in sequences:
            a = self.gmeta.get_aspect_ratio(seq_name)
            if a < 1.45 or a > 1.9:
                to_remove.add(seq_name)

    if 'starts' in cleanup:
        # Fix incorrect start frames for some objects found with ytvos_start_frames_test()
        bad_start_frames = [("0e27472bea", '2', ['00055', '00060'], '00065'),
                            ("5937b08d69", '4', ['00000'], '00005'),
                            ("5e1ce354fd", '5', ['00010', '00015'], '00020'),
                            ("7053e4f41e", '2', ['00000', '00005', '00010', '00015'], '00020'),
                            ("720e3fa04c", '2', ['00050'], '00055'),
                            ("c73c8e747f", '2', ['00035'], '00040')]
        for seq_name, obj_id, bad_frames, good_frame in bad_start_frames:
            # bad_frames is from the meta.json included with the dataset;
            # good_frame is from the generated meta - the first actual frame where the object was seen.
            if seq_name in self.meta._data:
                frames = self.meta.object_frames(seq_name, obj_id)
                for f in bad_frames:
                    frames.remove(f)
                assert frames[0] == good_frame

    sequences = [seq for seq in sequences if seq not in to_remove]

    self.sequence_names = sequences
    self._samples = []

    for seq in sequences:
        obj_ids = self.meta.object_ids(seq)
        if self.multiobj:  # Multiple objects per sample
            self._samples.append((seq, obj_ids))
        else:  # One object per sample
            self._samples.extend([(seq, [obj_id]) for obj_id in obj_ids])

    print("%s loaded." % self.get_name())
    if len(to_remove) > 0:
        print("  %d sequences were removed, (%d remaining)." % (len(to_remove), len(sequences)))
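# Usage sketch (assumption): loading the 2019 train split with the aspect-ratio
# cleanup enabled and one object per sample. The class name YouTubeVOS is
# assumed from the file this __init__ belongs to.
#
#     dataset = YouTubeVOS(version='2019', split='train', cleanup=['aspect'], multiobj=False)
#     seq_name, obj_ids = dataset._samples[0]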