def run(settings):
    """Set up and start training of TransT driven by a ``CircuitTranstActor``.

    The function fills in the ``settings`` struct, builds a GOT-10k based
    training pipeline, creates the TransT ResNet-50 network and its loss,
    disables gradients on every parameter whose name matches none of the
    trainable tags, and finally runs ``LTRTrainer``.

    Args:
        settings: Training settings object. It is mutated in place and must
            provide ``settings.env.got10k_dir``.
    """
    # ---- General run configuration ------------------------------------
    settings.device = 'cuda'
    settings.description = 'TransT with default settings.'
    settings.batch_size = 24  # alternative value noted in the original: 38
    settings.num_workers = min(settings.batch_size, 8)
    settings.multi_gpu = False
    settings.print_interval = 1

    # ---- Input normalisation ------------------------------------------
    settings.normalize_mean = [0.485, 0.456, 0.406]
    settings.normalize_std = [0.229, 0.224, 0.225]

    # ---- Crop geometry ------------------------------------------------
    # The crop sizes are derived from the feature sizes with a factor of 8.
    settings.search_area_factor = 4.0
    settings.template_area_factor = 2.0
    settings.search_feature_sz = 32
    settings.template_feature_sz = 16
    settings.search_sz = 8 * settings.search_feature_sz
    settings.temp_sz = 8 * settings.template_feature_sz
    settings.center_jitter_factor = {'search': 3, 'template': 0}
    settings.scale_jitter_factor = {'search': 0.25, 'template': 0}
    settings.sequence_length = 8
    settings.search_gap = 5
    settings.init_ckpt = "pytracking/networks/transt.pth"

    # ---- Transformer hyper-parameters ---------------------------------
    settings.position_embedding = 'sine'
    settings.hidden_dim = 256
    settings.dropout = 0.1
    settings.nheads = 8
    settings.dim_feedforward = 2048
    settings.featurefusion_layers = 4

    # ---- Label-related settings ---------------------------------------
    settings.sigma = 1 / 4 / 5.
    settings.kernel = 4
    settings.feature = 32  # alternative value noted in the original: 18
    settings.output_sz = 256  # original note: settings.feature * 16
    settings.end_pad_if_even = False
    settings.label_function_params = True

    # ---- Training data ------------------------------------------------
    # Only GOT-10k is active; the original also carried disabled lines for
    # LaSOT, TrackingNet and COCO.
    train_sets = [Got10k(settings.env.got10k_dir, split='vottrain')]
    set_weights = [1]

    # Augmentation applied jointly to all frames of a sample.
    joint_aug = tfm.Transform(tfm.ToGrayscale(probability=0.05))

    # Augmentation applied to each image individually.
    per_image_aug = tfm.Transform(
        tfm.ToTensorAndJitter(0.2),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    # Processing performed on the sampled frames.
    processing_args = dict(
        search_area_factor=settings.search_area_factor,
        template_area_factor=settings.template_area_factor,
        search_sz=settings.search_sz,
        temp_sz=settings.temp_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        transform=per_image_aug,
        label_function_params=settings.label_function_params,
        joint_transform=joint_aug,
    )
    train_processing = processing.TransTProcessing(**processing_args)

    # Sampler producing the training samples.
    train_dataset = sampler.TransTSampler(
        train_sets,
        set_weights,
        samples_per_epoch=1000 * settings.batch_size,
        max_gap=100,
        processing=train_processing,
        num_search_frames=settings.sequence_length,
        frame_sample_mode="causal",
    )

    # Loader wrapping the sampler.
    train_loader = LTRLoader(
        'train',
        train_dataset,
        training=True,
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        shuffle=True,
        drop_last=True,
        stack_dim=0,
    )

    # ---- Network, loss and actor --------------------------------------
    network = transt_models.transt_resnet50(settings)
    if settings.multi_gpu:
        # Wrap the network for multi GPU training.
        network = MultiGPU(network, dim=0)

    criterion = transt_models.transt_loss(settings)

    # NOTE(review): this count is taken before gradients are disabled
    # further down, so it does not reflect the freeze - confirm intent.
    trainable_count = 0
    for weight in network.parameters():
        if weight.requires_grad:
            trainable_count += weight.numel()
    print('number of params:', trainable_count)

    actor = actors.CircuitTranstActor(net=network, objective=criterion)

    # ---- Optimizer ----------------------------------------------------
    # Only parameters whose names contain one of these tags are handed to the
    # optimizer; they use a learning rate of 1e-3.
    optimised_tags = ("circuit", "mix", "new")
    # Parameters matching these tags keep requires_grad; all others lose it.
    # NOTE(review): "class_embed", "bbox_embed" and "decoder" keep gradients
    # but are not part of any optimizer group - confirm this is intended.
    grad_tags = optimised_tags + ("class_embed", "bbox_embed", "decoder")

    param_groups = [
        {
            "params": [
                weight
                for name, weight in network.named_parameters()
                if any(tag in name for tag in optimised_tags)
            ],
            "lr": 1e-3,
        },
    ]

    for name, weight in network.named_parameters():
        if any(tag in name for tag in grad_tags):
            continue
        weight.requires_grad = False
        print("Removing grad on {}".format(name))

    optimizer = torch.optim.AdamW(param_groups, lr=1e-5, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 500)

    # ---- Trainer ------------------------------------------------------
    trainer = LTRTrainer(actor, [train_loader], optimizer, settings, scheduler)

    # Set fail_safe=False when debugging.
    trainer.train(1000, load_latest=True, fail_safe=True)
# Example #2
# 0
def run(settings):
    """Set up and start LWL training of the full network.

    Builds YouTube-VOS/DAVIS training data and a YouTube-VOS validation set,
    creates the steepest-descent ResNet-50 network, initialises it from the
    ``lwl_stage1`` checkpoint, and trains it with a Lovasz segmentation loss.

    Args:
        settings: Training settings object. It is mutated in place and must
            provide ``settings.env.workspace_dir``.
    """
    # ---- General run configuration ------------------------------------
    settings.description = 'Default train settings for training full network'
    settings.batch_size = 20
    settings.num_workers = 8
    settings.multi_gpu = True
    settings.print_interval = 1
    settings.normalize_mean = [102.9801, 115.9465, 122.7717]
    settings.normalize_std = [1.0, 1.0, 1.0]

    settings.feature_sz = (52, 30)

    # Crop-generation settings; the original comment points to the processing
    # class documentation in ltr/data/processing.py for details.
    # The image crop is 16 times the feature size along each axis.
    settings.output_sz = tuple(16 * side for side in settings.feature_sz)
    settings.search_area_factor = 5.0
    settings.crop_type = 'inside_major'
    settings.max_scale_change = None

    settings.center_jitter_factor = {'train': 3, 'test': (5.5, 4.5)}
    settings.scale_jitter_factor = {'train': 0.25, 'test': 0.5}

    # ---- Datasets -----------------------------------------------------
    train_ytvos = YouTubeVOS(version="2019", multiobj=False, split='jjtrain')
    train_davis = Davis(version='2017', multiobj=False, split='train')

    val_ytvos = YouTubeVOS(version="2019", multiobj=False, split='jjvalid')

    # ---- Transforms ---------------------------------------------------
    joint_aug = tfm.Transform(
        tfm.ToBGR(),
        tfm.ToGrayscale(probability=0.05),
        tfm.RandomHorizontalFlip(probability=0.5),
    )

    train_aug = tfm.Transform(
        tfm.RandomAffine(
            p_flip=0.0,
            max_rotation=15.0,
            max_shear=0.0,
            max_ar_factor=0.0,
            max_scale=0.2,
            pad_amount=0,
        ),
        tfm.ToTensorAndJitter(0.2, normalize=False),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    val_aug = tfm.Transform(
        tfm.ToTensorAndJitter(0.0, normalize=False),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    # ---- Processing ---------------------------------------------------
    # Train and validation processing differ only in the per-image transform.
    processing_args = dict(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        crop_type=settings.crop_type,
        max_scale_change=settings.max_scale_change,
        joint_transform=joint_aug,
        new_roll=True,
    )
    train_processing = processing.LWLProcessing(transform=train_aug, **processing_args)
    val_processing = processing.LWLProcessing(transform=val_aug, **processing_args)

    # ---- Samplers -----------------------------------------------------
    sampler_args = dict(max_gap=100, num_test_frames=3, num_train_frames=1)
    train_dataset = sampler.LWLSampler(
        [train_ytvos, train_davis],
        [6, 1],
        samples_per_epoch=settings.batch_size * 1000,
        processing=train_processing,
        **sampler_args,
    )
    val_dataset = sampler.LWLSampler(
        [val_ytvos],
        [1],
        samples_per_epoch=settings.batch_size * 100,
        processing=val_processing,
        **sampler_args,
    )

    # ---- Loaders ------------------------------------------------------
    loader_args = dict(
        num_workers=settings.num_workers,
        stack_dim=1,
        batch_size=settings.batch_size,
    )
    train_loader = LTRLoader('train', train_dataset, training=True, **loader_args)
    val_loader = LTRLoader(
        'val', val_dataset, training=False, epoch_interval=5, **loader_args
    )

    # ---- Network ------------------------------------------------------
    net = lwl_networks.steepest_descent_resnet50(
        filter_size=3,
        num_filters=16,
        optim_iter=5,
        backbone_pretrained=True,
        out_feature_dim=512,
        frozen_backbone_layers=['conv1', 'bn1', 'layer1'],
        label_encoder_dims=(16, 32, 64),
        use_bn_in_label_enc=False,
        clf_feat_blocks=0,
        final_conv=True,
        backbone_type='mrcnn',
    )

    # Start from the weights of the earlier training stage.
    stage1_net = network_loading.load_trained_network(
        settings.env.workspace_dir,
        'ltr/lwl/lwl_stage1/LWTLNet_ep0070.pth.tar',
    )
    net.load_state_dict(stage1_net.state_dict())

    if settings.multi_gpu:
        # Wrap the network for multi GPU training.
        net = MultiGPU(net, dim=1)

    # ---- Loss and actor -----------------------------------------------
    objective = {'segm': LovaszSegLoss(per_image=False)}
    loss_weight = {'segm': 100.0}

    actor = segm_actors.LWLActor(
        net=net,
        objective=objective,
        loss_weight=loss_weight,
        num_refinement_iter=2,
        disable_all_bn=True,
    )

    # ---- Optimizer ----------------------------------------------------
    # One parameter group per sub-module, each with its own learning rate.
    model = actor.net
    lr_plan = [
        (model.target_model.filter_initializer, 5e-5),
        (model.target_model.filter_optimizer, 2e-5),
        (model.target_model.feature_extractor, 2e-5),
        (model.decoder, 2e-5),
        (model.label_encoder, 2e-5),
        (model.feature_extractor.layer2, 2e-5),
        (model.feature_extractor.layer3, 2e-5),
        (model.feature_extractor.layer4, 2e-5),
    ]
    optimizer = optim.Adam(
        [{'params': module.parameters(), 'lr': rate} for module, rate in lr_plan],
        lr=2e-4,
    )

    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [25, 75], gamma=0.2)

    # ---- Trainer ------------------------------------------------------
    trainer = LTRTrainer(
        actor, [train_loader, val_loader], optimizer, settings, scheduler
    )
    trainer.train(80, load_latest=True, fail_safe=True)
# Example #3
# 0
def run(settings):
    """Set up and start LWL training for VOS with box initialization.

    Builds YouTube-VOS/COCO training data and a YouTube-VOS validation set,
    creates the box-initialised steepest-descent ResNet-50 network, copies
    four sub-modules from the ``lwl_stage2`` checkpoint, and optimizes only
    the box label encoder.

    Args:
        settings: Training settings object. It is mutated in place and must
            provide ``settings.env.workspace_dir``.
    """
    # ---- General run configuration ------------------------------------
    settings.description = 'Default train settings for training VOS with box initialization.'
    settings.batch_size = 8
    settings.num_workers = 4
    settings.multi_gpu = False
    settings.print_interval = 1
    settings.normalize_mean = [102.9801, 115.9465, 122.7717]
    settings.normalize_std = [1.0, 1.0, 1.0]

    settings.feature_sz = (52, 30)
    # The image crop is 16 times the feature size along each axis.
    settings.output_sz = tuple(16 * side for side in settings.feature_sz)
    settings.search_area_factor = 5.0
    settings.crop_type = 'inside_major'
    settings.max_scale_change = None
    settings.device = "cuda:0"
    settings.center_jitter_factor = {'train': 3, 'test': (5.5, 4.5)}
    settings.scale_jitter_factor = {'train': 0.25, 'test': 0.5}

    settings.min_target_area = 500

    # ---- Datasets -----------------------------------------------------
    train_ytvos = YouTubeVOS(version="2019", multiobj=False, split='jjtrain')
    val_ytvos = YouTubeVOS(version="2019", multiobj=False, split='jjvalid')
    train_coco = MSCOCOSeq()

    # ---- Transforms ---------------------------------------------------
    joint_aug = tfm.Transform(
        tfm.ToBGR(),
        tfm.ToGrayscale(probability=0.05),
        tfm.RandomHorizontalFlip(probability=0.5),
    )

    train_aug = tfm.Transform(
        tfm.ToTensorAndJitter(0.2, normalize=False),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    val_aug = tfm.Transform(
        tfm.ToTensorAndJitter(0.0, normalize=False),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    # ---- Processing ---------------------------------------------------
    # Train and validation processing differ only in the per-image transform.
    processing_args = dict(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        crop_type=settings.crop_type,
        max_scale_change=settings.max_scale_change,
        joint_transform=joint_aug,
        new_roll=True,
    )
    train_processing = processing.LWLProcessing(transform=train_aug, **processing_args)
    val_processing = processing.LWLProcessing(transform=val_aug, **processing_args)

    # ---- Samplers -----------------------------------------------------
    sampler_args = dict(max_gap=100, num_test_frames=1, num_train_frames=1)
    train_dataset = sampler.LWLSampler(
        [train_ytvos, train_coco],
        [1, 1],
        samples_per_epoch=settings.batch_size * 1000,
        processing=train_processing,
        **sampler_args,
    )
    val_dataset = sampler.LWLSampler(
        [val_ytvos],
        [1],
        samples_per_epoch=settings.batch_size * 100,
        processing=val_processing,
        **sampler_args,
    )

    # ---- Loaders ------------------------------------------------------
    loader_args = dict(
        num_workers=settings.num_workers,
        stack_dim=1,
        batch_size=settings.batch_size,
    )
    train_loader = LTRLoader('train', train_dataset, training=True, **loader_args)
    val_loader = LTRLoader(
        'val', val_dataset, training=False, epoch_interval=5, **loader_args
    )

    # ---- Network ------------------------------------------------------
    net = lwt_box_networks.steepest_descent_resnet50(
        filter_size=3,
        num_filters=16,
        optim_iter=5,
        backbone_pretrained=True,
        out_feature_dim=512,
        frozen_backbone_layers=['conv1', 'bn1', 'layer1'],
        label_encoder_dims=(16, 32, 64),
        use_bn_in_label_enc=False,
        clf_feat_blocks=0,
        final_conv=True,
        backbone_type='mrcnn',
        box_label_encoder_dims=(64, 64),
        final_bn=False,
    )

    stage2_net = network_loading.load_trained_network(
        settings.env.workspace_dir,
        'ltr/lwl/lwl_stage2/LWTLNet_ep0080.pth.tar',
    )

    # Copy the weights of the sub-modules shared with the earlier stage.
    for part in ('feature_extractor', 'target_model', 'decoder', 'label_encoder'):
        getattr(net, part).load_state_dict(getattr(stage2_net, part).state_dict())

    if settings.multi_gpu:
        # Wrap the network for multi GPU training.
        net = MultiGPU(net, dim=1)

    # ---- Loss and actor -----------------------------------------------
    objective = {'segm': LovaszSegLoss(per_image=False)}
    loss_weight = {'segm': 100.0, 'segm_box': 10.0, 'segm_train': 10}

    actor = lwtl_actors.LWLBoxActor(
        net=net, objective=objective, loss_weight=loss_weight
    )

    # ---- Optimizer ----------------------------------------------------
    # Only the box label encoder is handed to the optimizer.
    box_encoder_group = {
        'params': actor.net.box_label_encoder.parameters(),
        'lr': 1e-3,
    }
    optimizer = optim.Adam([box_encoder_group], lr=2e-4)

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.2)

    # ---- Trainer ------------------------------------------------------
    trainer = LTRTrainer(
        actor, [train_loader, val_loader], optimizer, settings, scheduler
    )
    trainer.train(50, load_latest=True, fail_safe=True)
# Example #4
# 0
def run(settings):
    """Set up and start DiMP training with a ResNet-50 backbone.

    Builds LaSOT/GOT-10k/TrackingNet/COCO training data and a GOT-10k
    validation set, creates ``dimpnet50`` with IoU and classification losses,
    and trains it with per-module learning rates.

    Args:
        settings: Training settings object. It is mutated in place and must
            provide ``settings.env.lasot_dir``, ``got10k_dir``,
            ``trackingnet_dir`` and ``coco_dir``.
    """
    # ---- General run configuration ------------------------------------
    settings.description = 'Default train settings for DiMP with ResNet50 as backbone.'
    settings.batch_size = 10
    settings.num_workers = 8
    settings.multi_gpu = False
    settings.print_interval = 1
    settings.normalize_mean = [0.485, 0.456, 0.406]
    settings.normalize_std = [0.229, 0.224, 0.225]
    settings.search_area_factor = 5.0
    settings.output_sigma_factor = 1/4
    settings.target_filter_sz = 4
    settings.feature_sz = 18
    settings.output_sz = 16 * settings.feature_sz
    settings.center_jitter_factor = {'train': 3, 'test': 4.5}
    settings.scale_jitter_factor = {'train': 0.25, 'test': 0.5}
    settings.hinge_threshold = 0.05
    # The original also carried a disabled settings.print_stats list:
    # ['Loss/total', 'Loss/iou', 'ClfTrain/clf_ce', 'ClfTrain/test_loss']

    env = settings.env

    # ---- Datasets -----------------------------------------------------
    train_lasot = Lasot(env.lasot_dir, split='train')
    train_got10k = Got10k(env.got10k_dir, split='vottrain')
    train_trackingnet = TrackingNet(env.trackingnet_dir, set_ids=list(range(4)))
    train_coco = MSCOCOSeq(env.coco_dir)

    val_got10k = Got10k(env.got10k_dir, split='votval')

    # ---- Transforms ---------------------------------------------------
    joint_aug = tfm.Transform(tfm.ToGrayscale(probability=0.05))

    train_aug = tfm.Transform(
        tfm.ToTensorAndJitter(0.2),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    val_aug = tfm.Transform(
        tfm.ToTensor(),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    # ---- Processing ---------------------------------------------------
    output_sigma = settings.output_sigma_factor / settings.search_area_factor
    proposal_params = {
        'min_iou': 0.1,
        'boxes_per_frame': 8,
        'sigma_factor': [0.01, 0.05, 0.1, 0.2, 0.3],
    }
    label_params = {
        'feature_sz': settings.feature_sz,
        'sigma_factor': output_sigma,
        'kernel_sz': settings.target_filter_sz,
    }

    # Train and validation processing differ only in the per-image transform.
    processing_args = dict(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        proposal_params=proposal_params,
        label_function_params=label_params,
        joint_transform=joint_aug,
    )
    train_processing = processing.DiMPProcessing(transform=train_aug, **processing_args)
    val_processing = processing.DiMPProcessing(transform=val_aug, **processing_args)

    # ---- Training sampler and loader ----------------------------------
    frame_args = dict(max_gap=30, num_test_frames=3, num_train_frames=3)
    train_dataset = sampler.DiMPSampler(
        [train_lasot, train_got10k, train_trackingnet, train_coco],
        [0.25, 1, 1, 1],
        samples_per_epoch=26000,
        processing=train_processing,
        **frame_args,
    )
    train_loader = LTRLoader(
        'train',
        train_dataset,
        training=True,
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        shuffle=True,
        drop_last=True,
        stack_dim=1,
    )

    # ---- Validation sampler and loader --------------------------------
    val_dataset = sampler.DiMPSampler(
        [val_got10k],
        [1],
        samples_per_epoch=5000,
        processing=val_processing,
        **frame_args,
    )
    val_loader = LTRLoader(
        'val',
        val_dataset,
        training=False,
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        shuffle=False,
        drop_last=True,
        epoch_interval=5,
        stack_dim=1,
    )

    # ---- Network ------------------------------------------------------
    net = dimpnet.dimpnet50(
        filter_size=settings.target_filter_sz,
        backbone_pretrained=True,
        optim_iter=5,
        clf_feat_norm=True,
        clf_feat_blocks=0,
        final_conv=True,
        out_feature_dim=512,
        optim_init_step=0.9,
        optim_init_reg=0.1,
        init_gauss_sigma=output_sigma * settings.feature_sz,
        num_dist_bins=100,
        bin_displacement=0.1,
        mask_init_factor=3.0,
        target_mask_act='sigmoid',
        score_act='relu',
    )

    if settings.multi_gpu:
        # Wrap the network for multi GPU training.
        net = MultiGPU(net, dim=1)

    # ---- Loss and actor -----------------------------------------------
    objective = {
        'iou': nn.MSELoss(),
        'test_clf': ltr_losses.LBHinge(threshold=settings.hinge_threshold),
    }
    loss_weight = {
        'iou': 1,
        'test_clf': 100,
        'test_init_clf': 100,
        'test_iter_clf': 400,
    }

    actor = actors.DiMPActor(net=net, objective=objective, loss_weight=loss_weight)

    # ---- Optimizer ----------------------------------------------------
    # The bb_regressor group sets no 'lr' of its own, so it uses the
    # optimizer-wide value passed as lr below.
    model = actor.net
    param_groups = [
        {'params': model.classifier.filter_initializer.parameters(), 'lr': 5e-5},
        {'params': model.classifier.filter_optimizer.parameters(), 'lr': 5e-4},
        {'params': model.classifier.feature_extractor.parameters(), 'lr': 5e-5},
        {'params': model.bb_regressor.parameters()},
        {'params': model.feature_extractor.parameters(), 'lr': 2e-5},
    ]
    optimizer = optim.Adam(param_groups, lr=2e-4)

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

    # ---- Trainer ------------------------------------------------------
    trainer = LTRTrainer(
        actor, [train_loader, val_loader], optimizer, settings, scheduler
    )
    trainer.train(50, load_latest=True, fail_safe=True)
# Example #5
# 0
def run(settings):
    """Set up and start LWL training with the backbone weights fixed.

    Builds YouTube-VOS/DAVIS training data and a YouTube-VOS validation set,
    creates the steepest-descent ResNet-50 network with every listed backbone
    layer frozen, loads converted Mask R-CNN weights into the feature
    extractor, and trains the remaining modules with a Lovasz segmentation
    loss.

    Args:
        settings: Training settings object. It is mutated in place and must
            provide ``settings.env.pretrained_networks``, the directory that
            holds ``e2e_mask_rcnn_R_50_FPN_1x_converted.pkl``.
    """
    settings.description = 'Default train settings with backbone weights fixed. We initialize the backbone ResNet with ' \
                           'pre-trained Mask-RCNN weights. These weights can be obtained from ' \
                           'https://drive.google.com/file/d/12pVHmhqtxaJ151dZrXN1dcgUa7TuAjdA/view?usp=sharing. ' \
                           'Download and save these weights in env_settings.pretrained_networks directory'
    settings.batch_size = 20
    settings.num_workers = 8
    settings.multi_gpu = True
    settings.print_interval = 1
    settings.normalize_mean = [102.9801, 115.9465, 122.7717]
    settings.normalize_std = [1.0, 1.0, 1.0]

    settings.feature_sz = (52, 30)

    # Settings used for generating the image crop input to the network. See
    # the processing class documentation in ltr/data/processing.py for details.
    # The image crop is 16 times the feature size along each axis.
    settings.output_sz = (settings.feature_sz[0] * 16,
                          settings.feature_sz[1] * 16)
    settings.search_area_factor = 5.0
    settings.crop_type = 'inside_major'
    settings.max_scale_change = None

    settings.center_jitter_factor = {'train': 3, 'test': (5.5, 4.5)}
    settings.scale_jitter_factor = {'train': 0.25, 'test': 0.5}

    # Datasets
    ytvos_train = YouTubeVOS(version="2019", multiobj=False, split='jjtrain')
    davis_train = Davis(version='2017', multiobj=False, split='train')

    ytvos_val = YouTubeVOS(version="2019", multiobj=False, split='jjvalid')

    # Transform applied jointly to all frames of a sample.
    transform_joint = tfm.Transform(tfm.ToBGR(),
                                    tfm.ToGrayscale(probability=0.05),
                                    tfm.RandomHorizontalFlip(probability=0.5))

    # Per-image transform for training.
    transform_train = tfm.Transform(
        tfm.RandomAffine(p_flip=0.0,
                         max_rotation=15.0,
                         max_shear=0.0,
                         max_ar_factor=0.0,
                         max_scale=0.2,
                         pad_amount=0),
        tfm.ToTensorAndJitter(0.2, normalize=False),
        tfm.Normalize(mean=settings.normalize_mean,
                      std=settings.normalize_std))

    # Per-image transform for validation.
    transform_val = tfm.Transform(
        tfm.ToTensorAndJitter(0.0, normalize=False),
        tfm.Normalize(mean=settings.normalize_mean,
                      std=settings.normalize_std))

    # Train and validation processing differ only in the per-image transform.
    data_processing_train = processing.LWLProcessing(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        crop_type=settings.crop_type,
        max_scale_change=settings.max_scale_change,
        transform=transform_train,
        joint_transform=transform_joint,
        new_roll=True)

    data_processing_val = processing.LWLProcessing(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        crop_type=settings.crop_type,
        max_scale_change=settings.max_scale_change,
        transform=transform_val,
        joint_transform=transform_joint,
        new_roll=True)

    # Train and validation samplers
    dataset_train = sampler.LWLSampler([ytvos_train, davis_train], [6, 1],
                                       samples_per_epoch=settings.batch_size * 1000,
                                       max_gap=100,
                                       num_test_frames=3,
                                       num_train_frames=1,
                                       processing=data_processing_train)
    dataset_val = sampler.LWLSampler([ytvos_val], [1],
                                     samples_per_epoch=settings.batch_size * 100,
                                     max_gap=100,
                                     num_test_frames=3,
                                     num_train_frames=1,
                                     processing=data_processing_val)

    # Loaders
    loader_train = LTRLoader('train',
                             dataset_train,
                             training=True,
                             num_workers=settings.num_workers,
                             stack_dim=1,
                             batch_size=settings.batch_size)

    loader_val = LTRLoader('val',
                           dataset_val,
                           training=False,
                           num_workers=settings.num_workers,
                           epoch_interval=5,
                           stack_dim=1,
                           batch_size=settings.batch_size)

    # Network; every backbone layer named below is passed as frozen.
    net = lwl_networks.steepest_descent_resnet50(
        filter_size=3,
        num_filters=16,
        optim_iter=5,
        backbone_pretrained=True,
        out_feature_dim=512,
        frozen_backbone_layers=['conv1', 'bn1', 'layer1',
                                'layer2', 'layer3', 'layer4'],
        label_encoder_dims=(16, 32, 64),
        use_bn_in_label_enc=False,
        clf_feat_blocks=0,
        final_conv=True,
        backbone_type='mrcnn')

    # Load pre-trained maskrcnn weights.
    weights_path = os.path.join(settings.env.pretrained_networks,
                                'e2e_mask_rcnn_R_50_FPN_1x_converted.pkl')
    # Deserialize onto the CPU so the load does not depend on the device the
    # checkpoint was saved from; load_state_dict copies the values into the
    # network's existing parameters afterwards.
    # NOTE(security): torch.load unpickles the file - only load weights
    # obtained from a trusted source.
    pretrained_weights = torch.load(weights_path, map_location='cpu')

    net.feature_extractor.load_state_dict(pretrained_weights)

    # Wrap the network for multi GPU training
    if settings.multi_gpu:
        net = MultiGPU(net, dim=1)

    # Loss function
    objective = {
        'segm': LovaszSegLoss(per_image=False),
    }

    loss_weight = {'segm': 100.0}

    actor = segm_actors.LWLActor(net=net,
                                 objective=objective,
                                 loss_weight=loss_weight,
                                 num_refinement_iter=2,
                                 disable_all_bn=True)

    # Optimizer: one parameter group per trained sub-module. No group is
    # created for actor.net.feature_extractor (the backbone).
    optimizer = optim.Adam(
        [{
            'params': actor.net.target_model.filter_initializer.parameters(),
            'lr': 5e-5
        }, {
            'params': actor.net.target_model.filter_optimizer.parameters(),
            'lr': 1e-4
        }, {
            'params': actor.net.target_model.feature_extractor.parameters(),
            'lr': 2e-5
        }, {
            'params': actor.net.decoder.parameters(),
            'lr': 1e-4
        }, {
            'params': actor.net.label_encoder.parameters(),
            'lr': 2e-4
        }],
        lr=2e-4)

    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                  milestones=[40],
                                                  gamma=0.2)

    trainer = LTRTrainer(actor, [loader_train, loader_val], optimizer,
                         settings, lr_scheduler)

    trainer.train(70, load_latest=True, fail_safe=True)
# Example #6
# 0
def run(settings):
    """Configure and launch TransT training on LaSOT, GOT-10k and TrackingNet.

    Populates ``settings`` in place (batch/worker profile, crop geometry,
    jitter factors, transformer hyper-parameters), builds the transform ->
    processing -> sampler -> loader chain, creates the model returned by
    ``transt_models.transt_resnet50`` together with its loss and actor, and
    finally calls ``LTRTrainer.train(1000, load_latest=True, fail_safe=True)``.

    Args:
        settings: Training settings object. It is mutated by this function and
            must already expose ``settings.env`` with ``lasot_dir``,
            ``got10k_dir`` and ``trackingnet_dir``.
    """
    # Throughput profile. Set ``debug`` to True for the small configuration
    # (batch of 4, no worker processes, single GPU).
    debug = False
    (settings.batch_size,
     settings.num_workers,
     settings.multi_gpu) = (4, 0, False) if debug else (38, 20, True)

    # General run options
    settings.device = 'cuda'
    settings.description = 'TransT with default settings.'
    settings.print_interval = 1
    settings.move_data_to_gpu = True
    settings.init_ckpt = "pytracking/networks/transt.pth"

    # Per-channel normalisation constants handed to tfm.Normalize below
    settings.normalize_mean = [0.485, 0.456, 0.406]
    settings.normalize_std = [0.229, 0.224, 0.225]

    # Crop geometry: image sizes are 8x the corresponding feature sizes
    settings.search_area_factor = 6.
    settings.template_area_factor = 2.
    settings.search_feature_sz = 32
    settings.template_feature_sz = 16
    settings.search_sz = settings.search_feature_sz * 8
    settings.temp_sz = settings.template_feature_sz * 8

    # Jitter factors; the template entries are zero
    settings.center_jitter_factor = {'search': 2.0, 'template': 0}
    settings.scale_jitter_factor = {'search': 0.05, 'template': 0}

    settings.sequence_length = 34
    settings.rand = False

    # Transformer hyper-parameters
    settings.position_embedding = 'sine'
    settings.hidden_dim = 256
    settings.dropout = 0.1
    settings.nheads = 8
    settings.dim_feedforward = 2048
    settings.featurefusion_layers = 4

    # Training datasets
    env = settings.env
    lasot_ds = Lasot(env.lasot_dir, split='train')
    got10k_ds = Got10k(env.got10k_dir, split='vottrain')
    trackingnet_ds = TrackingNet(env.trackingnet_dir, set_ids=list(range(4)))

    # Joint transform (passed as ``joint_transform``) and the transform passed
    # as ``transform`` to the processing module.
    joint_tfm = tfm.Transform(tfm.ToGrayscale(probability=0.05))
    jitter_step = tfm.ToTensorAndJitter(0.2)
    normalize_step = tfm.Normalize(mean=settings.normalize_mean,
                                   std=settings.normalize_std)
    train_tfm = tfm.Transform(jitter_step, normalize_step)

    # Processing applied to the sampled training data
    processing_args = dict(
        search_area_factor=settings.search_area_factor,
        template_area_factor=settings.template_area_factor,
        search_sz=settings.search_sz,
        temp_sz=settings.temp_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        joint=False,  # whether or not to apply the same transform to every image
        transform=train_tfm,
        rand=settings.rand,
        label_function_params=None,
        joint_transform=joint_tfm,
    )
    crop_pipeline = processing.TransTProcessing(**processing_args)

    # Sampler: the three datasets are given equal weights
    train_sources = [lasot_ds, got10k_ds, trackingnet_ds]
    train_dataset = sampler.TransTSampler(
        train_sources,
        [1] * len(train_sources),
        samples_per_epoch=1000 * settings.batch_size,
        max_gap=100,
        processing=crop_pipeline,
    )

    # Loader: memory is pinned only when the data is not moved to the GPU
    train_loader = LTRLoader(
        'train',
        train_dataset,
        training=True,
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        shuffle=True,
        drop_last=True,
        stack_dim=0,
        pin_memory=not settings.move_data_to_gpu,
    )

    # Network, optionally wrapped for multi-GPU training
    net = transt_models.transt_resnet50(settings)
    if settings.multi_gpu:
        net = MultiGPU(net, dim=0)

    criterion = transt_models.transt_loss(settings)

    trainable_count = 0
    for weight in net.parameters():
        if weight.requires_grad:
            trainable_count += weight.numel()
    print('number of params:', trainable_count)

    actor = actors.TranstActor(net=net, objective=criterion)

    # Optimizer: trainable parameters whose name contains "backbone" get their
    # own learning rate (1e-5); all other trainable parameters use the
    # optimizer default (1e-4).
    head_params = []
    backbone_params = []
    for name, weight in net.named_parameters():
        if not weight.requires_grad:
            continue
        if "backbone" in name:
            backbone_params.append(weight)
        else:
            head_params.append(weight)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": 1e-5},
        ],
        lr=1e-4,
        weight_decay=1e-4,
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500)

    # Trainer with a single (training) loader
    trainer = LTRTrainer(actor, [train_loader], optimizer, settings, lr_scheduler)

    # Set fail_safe=False when debugging
    trainer.train(1000, load_latest=True, fail_safe=True)
Beispiel #7
0
def run(settings):
    """Configure and launch PrDiMP (ResNet18 backbone) training.

    Populates ``settings`` in place, builds training and validation pipelines
    that share one ``KLDiMPProcessing`` configuration (only the image
    transform differs), creates the ``dimpnet.klcedimpnet18`` network with
    its KL-regression objectives, and calls
    ``LTRTrainer.train(50, load_latest=True, fail_safe=True)``.

    Args:
        settings: Training settings object. It is mutated by this function and
            must already expose ``settings.env`` with ``lasot_dir``,
            ``got10k_dir``, ``trackingnet_dir`` and ``coco_dir``.
    """
    settings.description = 'Default train settings for PrDiMP with ResNet18 as backbone.'

    # Batch / worker configuration
    settings.batch_size = 26
    settings.num_workers = 8
    settings.multi_gpu = False
    settings.print_interval = 1
    settings.print_stats = ['Loss/total', 'Loss/bb_ce', 'ClfTrain/clf_ce']

    # Per-channel normalisation constants handed to tfm.Normalize below
    settings.normalize_mean = [0.485, 0.456, 0.406]
    settings.normalize_std = [0.229, 0.224, 0.225]

    # Crop and label geometry: the output size is 16x the feature size
    settings.search_area_factor = 5.0
    settings.output_sigma_factor = 1/4
    settings.target_filter_sz = 4
    settings.feature_sz = 18
    settings.output_sz = settings.feature_sz * 16

    # Jitter factors, keyed by 'train' / 'test'
    settings.center_jitter_factor = {'train': 3, 'test': 4.5}
    settings.scale_jitter_factor = {'train': 0.25, 'test': 0.5}
    settings.hinge_threshold = 0.05

    # Datasets: four for training, one GOT-10k split for validation
    env = settings.env
    lasot_ds = Lasot(env.lasot_dir, split='train')
    got10k_ds = Got10k(env.got10k_dir, split='vottrain')
    trackingnet_ds = TrackingNet(env.trackingnet_dir, set_ids=list(range(4)))
    coco_ds = MSCOCOSeq(env.coco_dir)
    got10k_val_ds = Got10k(env.got10k_dir, split='votval')

    # Transforms: training uses ToTensorAndJitter, validation plain ToTensor
    joint_tfm = tfm.Transform(tfm.ToGrayscale(probability=0.05))
    train_tfm = tfm.Transform(
        tfm.ToTensorAndJitter(0.2),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )
    val_tfm = tfm.Transform(
        tfm.ToTensor(),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    # Label and proposal parameters for the processing module
    output_sigma = settings.output_sigma_factor / settings.search_area_factor
    proposal_params = {
        'boxes_per_frame': 128,
        'gt_sigma': (0.05, 0.05),
        'proposal_sigma': [(0.05, 0.05), (0.5, 0.5)],
    }
    label_params = {
        'feature_sz': settings.feature_sz,
        'sigma_factor': output_sigma,
        'kernel_sz': settings.target_filter_sz,
    }
    # The density label uses the same parameters plus 'normalize'
    label_density_params = {**label_params, 'normalize': True}

    # Everything the train and val processing modules have in common
    processing_common = dict(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='sequence',
        proposal_params=proposal_params,
        label_function_params=label_params,
        label_density_params=label_density_params,
        joint_transform=joint_tfm,
    )
    train_processing = processing.KLDiMPProcessing(transform=train_tfm, **processing_common)
    val_processing = processing.KLDiMPProcessing(transform=val_tfm, **processing_common)

    # Loader options shared by the train and val loaders
    loader_common = dict(
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        drop_last=True,
        stack_dim=1,
    )

    # Training sampler and loader
    train_dataset = sampler.DiMPSampler(
        [lasot_ds, got10k_ds, trackingnet_ds, coco_ds],
        [0.25, 1, 1, 1],
        samples_per_epoch=26000,
        max_gap=200,
        num_test_frames=3,
        num_train_frames=3,
        processing=train_processing,
    )
    train_loader = LTRLoader('train', train_dataset, training=True, shuffle=True,
                             **loader_common)

    # Validation sampler and loader
    val_dataset = sampler.DiMPSampler(
        [got10k_val_ds],
        [1],
        samples_per_epoch=5000,
        max_gap=200,
        num_test_frames=3,
        num_train_frames=3,
        processing=val_processing,
    )
    val_loader = LTRLoader('val', val_dataset, training=False, shuffle=False,
                           epoch_interval=5, **loader_common)

    # Network, optionally wrapped for multi-GPU training
    net = dimpnet.klcedimpnet18(
        filter_size=settings.target_filter_sz,
        backbone_pretrained=True,
        optim_iter=5,
        clf_feat_norm=True,
        final_conv=True,
        optim_init_step=1.0,
        optim_init_reg=0.05,
        optim_min_reg=0.05,
        gauss_sigma=output_sigma * settings.feature_sz,
        alpha_eps=0.05,
        normalize_label=True,
        init_initializer='zero',
    )
    if settings.multi_gpu:
        net = MultiGPU(net, dim=1)

    # Objectives and their weights
    objective = {
        'bb_ce': klreg_losses.KLRegression(),
        'clf_ce': klreg_losses.KLRegressionGrid(),
    }
    loss_weight = {
        'bb_ce': 0.0025,
        'clf_ce': 0.25,
        'clf_ce_init': 0.25,
        'clf_ce_iter': 1.0,
    }
    actor = tracking_actors.KLDiMPActor(net=net, objective=objective, loss_weight=loss_weight)

    # Optimizer: classifier and box regressor use 1e-3; the feature extractor
    # falls back to the optimizer default of 2e-4.
    param_groups = [
        {'params': actor.net.classifier.parameters(), 'lr': 1e-3},
        {'params': actor.net.bb_regressor.parameters(), 'lr': 1e-3},
        {'params': actor.net.feature_extractor.parameters()},
    ]
    optimizer = optim.Adam(param_groups, lr=2e-4)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

    trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler)
    trainer.train(50, load_latest=True, fail_safe=True)
Beispiel #8
0
def run(settings):
    """Configure and launch KYS training on GOT-10k.

    Populates ``settings`` in place, builds training and validation pipelines
    around ``KYSProcessing`` / ``KYSSampler``, loads a DiMP checkpoint
    (``dimp50.pth``) and copies its feature extractor, classifier and box
    regressor weights into a ``kysnet_res50`` network, disables gradients for
    those three sub-modules, and trains only ``actor.net.predictor`` via
    ``LTRTrainer.train(40, load_latest=True, fail_safe=True)``.

    Args:
        settings: Training settings object. It is mutated by this function and
            must already expose ``settings.env`` with ``got10k_dir`` and
            ``pretrained_networks``.
    """
    settings.description = ''
    settings.move_data_to_gpu = False

    # Batch / worker configuration
    settings.batch_size = 10
    settings.num_workers = 4
    settings.multi_gpu = True
    settings.test_sequence_length = 50
    settings.print_interval = 1

    # Per-channel normalisation constants handed to tfm.Normalize below
    settings.normalize_mean = [0.485, 0.456, 0.406]
    settings.normalize_std = [0.229, 0.224, 0.225]

    # Crop and label geometry: the output size is 16x the feature size
    settings.search_area_factor = 5.0
    settings.output_sigma_factor = 1/4
    settings.target_filter_sz = 4
    settings.feature_sz = 18
    settings.output_sz = settings.feature_sz * 16

    # Jitter parameters, with separate train_* / test_* entries
    settings.center_jitter_param = {
        'train_mode': 'uniform', 'train_factor': 3.0, 'train_limit_motion': False,
        'test_mode': 'uniform', 'test_factor': 4.5, 'test_limit_motion': True,
    }
    settings.scale_jitter_param = {'train_factor': 0.25, 'test_factor': 0.3}
    settings.hinge_threshold = 0.05

    settings.print_stats = [
        'Loss/total',
        'Loss/raw/test_clf',
        'Loss/raw/test_clf_acc',
        'Loss/raw/dimp_clf_acc',
        'Loss/raw/is_target',
        'Loss/raw/is_target_after_prop',
        'Loss/raw/test_seq_acc',
        'Loss/raw/dimp_seq_acc',
    ]

    # Datasets: GOT-10k only, 'vottrain' for training and 'votval' for validation
    env = settings.env
    got10k_ds = Got10k(env.got10k_dir, split='vottrain')
    got10k_val_ds = Got10k(env.got10k_dir, split='votval')

    # Transforms: training uses ToTensorAndJitter, validation plain ToTensor
    joint_tfm = tfm.Transform(tfm.ToGrayscale(probability=0.05))
    train_tfm = tfm.Transform(
        tfm.ToTensorAndJitter(0.2),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )
    val_tfm = tfm.Transform(
        tfm.ToTensor(),
        tfm.Normalize(mean=settings.normalize_mean, std=settings.normalize_std),
    )

    # Label parameters; no proposals are generated in this configuration
    output_sigma = settings.output_sigma_factor / settings.search_area_factor
    proposal_params = None
    label_params = {
        'feature_sz': settings.feature_sz,
        'sigma_factor': output_sigma,
        'kernel_sz': settings.target_filter_sz,
        'end_pad_if_even': True,
    }

    # Everything the train and val processing modules have in common
    processing_common = dict(
        search_area_factor=settings.search_area_factor,
        output_sz=settings.output_sz,
        center_jitter_param=settings.center_jitter_param,
        scale_jitter_param=settings.scale_jitter_param,
        proposal_params=proposal_params,
        label_function_params=label_params,
        joint_transform=joint_tfm,
        min_crop_inside_ratio=0.1,
    )
    train_processing = processing.KYSProcessing(transform=train_tfm, **processing_common)
    val_processing = processing.KYSProcessing(transform=val_tfm, **processing_common)

    # Sequence sampling configuration shared by both samplers
    sequence_sample_info = {
        'num_train_frames': 3,
        'num_test_frames': settings.test_sequence_length,
        'max_train_gap': 30,
        'allow_missing_target': True,
        'min_fraction_valid_frames': 0.5,
        'mode': 'Sequence',
    }

    # Loader options shared by the train and val loaders
    loader_common = dict(
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        drop_last=True,
        stack_dim=1,
    )

    # Training sampler and loader
    train_dataset = sampler.KYSSampler(
        [got10k_ds],
        [1],
        samples_per_epoch=settings.batch_size * 150,
        sequence_sample_info=sequence_sample_info,
        processing=train_processing,
        sample_occluded_sequences=True,
    )
    train_loader = LTRLoader('train', train_dataset, training=True, shuffle=True,
                             **loader_common)

    # Validation sampler and loader
    val_dataset = sampler.KYSSampler(
        [got10k_val_ds],
        [1],
        samples_per_epoch=1000,
        sequence_sample_info=sequence_sample_info,
        processing=val_processing,
        sample_occluded_sequences=True,
    )
    val_loader = LTRLoader('val', val_dataset, training=False, shuffle=False,
                           epoch_interval=5, **loader_common)

    # Load the base DiMP network whose weights are reused
    dimp_weights_path = os.path.join(env.pretrained_networks, 'dimp50.pth')
    base_net, _ = network_loading.load_network(checkpoint=dimp_weights_path)

    net = kysnet_models.kysnet_res50(
        optim_iter=3,
        cv_kernel_size=3,
        cv_max_displacement=9,
        cv_stride=1,
        init_gauss_sigma=output_sigma * settings.feature_sz,
        train_feature_extractor=False,
        train_iounet=False,
        detach_length=0,
        state_dim=8,
        representation_predictor_dims=(16,),
        conf_measure='entropy',
        dimp_thresh=0.05,
    )

    # Copy the pre-trained DiMP weights into the matching KYS sub-modules
    net.backbone_feature_extractor.load_state_dict(base_net.feature_extractor.state_dict())
    net.dimp_classifier.load_state_dict(base_net.classifier.state_dict())
    net.bb_regressor.load_state_dict(base_net.bb_regressor.state_dict())

    if settings.multi_gpu:
        net = MultiGPU(net, dim=1)

    # To be safe, explicitly turn off gradients for the three copied sub-modules
    frozen_modules = (net.backbone_feature_extractor, net.dimp_classifier, net.bb_regressor)
    for frozen in frozen_modules:
        for weight in frozen.parameters():
            weight.requires_grad_(False)

    # Objectives and their weights
    hinge_kwargs = dict(threshold=settings.hinge_threshold, return_per_sequence=False)
    objective = {
        'test_clf': ltr_losses.LBHingev2(**hinge_kwargs),
        'dimp_clf': ltr_losses.LBHingev2(**hinge_kwargs),
        'is_target': ltr_losses.IsTargetCellLoss(return_per_sequence=False),
        'clf_acc': ltr_losses.TrackingClassificationAccuracy(threshold=0.25),
    }
    loss_weight = {
        'test_clf': 1.0*500,
        'test_clf_orig': 50,
        'is_target': 0.1*500,
        'is_target_after_prop': 0.1*500,
    }

    dimp_jitter_fn = DiMPScoreJittering(
        distractor_ratio=0.1,
        p_distractor=0.3,
        max_distractor_enhance_factor=1.3,
        min_distractor_enhance_factor=0.8,
    )
    actor = actors.KYSActor(net=net, objective=objective, loss_weight=loss_weight,
                            dimp_jitter_fn=dimp_jitter_fn)

    # Only the predictor parameters are handed to the optimizer
    param_groups = [{'params': actor.net.predictor.parameters(), 'lr': 1e-2}]
    optimizer = optim.Adam(param_groups, lr=1e-2)
    lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

    trainer = LTRTrainer(actor, [train_loader, val_loader], optimizer, settings, lr_scheduler)
    trainer.train(40, load_latest=True, fail_safe=True)
Beispiel #9
0
def run(settings):
    """Configure and launch a TransT run on GOT-10k in 'pair' mode.

    Populates ``settings`` in place with a small single-GPU profile (batch of
    4, two workers), all jitter factors set to zero, and an AdamW optimizer
    whose learning rates and weight decay are all 0. It then builds the data
    pipeline, model, loss and actor, and calls
    ``LTRTrainer.train(1000, load_latest=True, fail_safe=True)``.

    Args:
        settings: Training settings object. It is mutated by this function and
            must already expose ``settings.env`` with ``got10k_dir``.
    """
    # General run options
    settings.device = 'cuda'
    settings.description = 'TransT with default settings.'
    settings.init_ckpt = "pytracking/networks/transt.pth"
    settings.print_interval = 1

    # Batch / worker configuration
    settings.batch_size = 4
    settings.num_workers = 2
    settings.multi_gpu = False

    # Per-channel normalisation constants handed to tfm.Normalize below
    settings.normalize_mean = [0.485, 0.456, 0.406]
    settings.normalize_std = [0.229, 0.224, 0.225]

    # Crop geometry: image sizes are 8x the corresponding feature sizes
    settings.search_area_factor = 4.0
    settings.template_area_factor = 2.0
    settings.search_feature_sz = 32
    settings.template_feature_sz = 16
    settings.search_sz = settings.search_feature_sz * 8
    settings.temp_sz = settings.template_feature_sz * 8

    # All jitter factors are zero in this configuration
    settings.center_jitter_factor = {'search': 0, 'template': 0}
    settings.scale_jitter_factor = {'search': 0., 'template': 0}

    # Transformer hyper-parameters
    settings.position_embedding = 'sine'
    settings.hidden_dim = 256
    settings.dropout = 0.1
    settings.nheads = 8
    settings.dim_feedforward = 2048
    settings.featurefusion_layers = 4

    # Training dataset: GOT-10k 'vottrain' split only
    got10k_ds = Got10k(settings.env.got10k_dir, split='vottrain')

    # Joint transform (passed as ``joint_transform``) and the transform passed
    # as ``transform`` to the processing module.
    joint_tfm = tfm.Transform(tfm.ToGrayscale(probability=0.05))
    jitter_step = tfm.ToTensorAndJitter(0.2)
    normalize_step = tfm.Normalize(mean=settings.normalize_mean,
                                   std=settings.normalize_std)
    train_tfm = tfm.Transform(jitter_step, normalize_step)

    # Processing applied to the sampled training pairs
    processing_args = dict(
        search_area_factor=settings.search_area_factor,
        template_area_factor=settings.template_area_factor,
        search_sz=settings.search_sz,
        temp_sz=settings.temp_sz,
        center_jitter_factor=settings.center_jitter_factor,
        scale_jitter_factor=settings.scale_jitter_factor,
        mode='pair',
        transform=train_tfm,
        joint_transform=joint_tfm,
    )
    crop_pipeline = processing.TransTProcessing(**processing_args)

    # Sampler and loader
    train_dataset = sampler.TransTSampler(
        [got10k_ds],
        [1],
        samples_per_epoch=1000 * settings.batch_size,
        max_gap=100,
        processing=crop_pipeline,
    )
    train_loader = LTRLoader(
        'train',
        train_dataset,
        training=True,
        batch_size=settings.batch_size,
        num_workers=settings.num_workers,
        shuffle=True,
        drop_last=True,
        stack_dim=0,
    )

    # Network, optionally wrapped for multi-GPU training
    net = transt_models.transt_resnet50(settings)
    if settings.multi_gpu:
        net = MultiGPU(net, dim=0)

    criterion = transt_models.transt_loss(settings)

    trainable_count = 0
    for weight in net.parameters():
        if weight.requires_grad:
            trainable_count += weight.numel()
    print('number of params:', trainable_count)

    actor = actors.TranstActor(net=net, objective=criterion)

    # Optimizer: trainable parameters are split by whether their name contains
    # "backbone". Both groups end up with a learning rate of 0.
    head_params = []
    backbone_params = []
    for name, weight in net.named_parameters():
        if not weight.requires_grad:
            continue
        if "backbone" in name:
            backbone_params.append(weight)
        else:
            head_params.append(weight)

    optimizer = torch.optim.AdamW(
        [
            {"params": head_params},
            {"params": backbone_params, "lr": 0.},
        ],
        lr=0.,
        weight_decay=0.,
    )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500)

    # Trainer with a single (training) loader
    trainer = LTRTrainer(actor, [train_loader], optimizer, settings, lr_scheduler)

    # Set fail_safe=False when debugging
    trainer.train(1000, load_latest=True, fail_safe=True)