Example #1
0
    def _init_pairwise_dataset(dataset_type: DatasetType, dir_path: str,
                               **kwargs) -> SiamesePairwiseDataset:
        """Build the pairwise dataset for the requested dataset type.

        Args:
            dataset_type: Which of the supported sequence datasets to load.
            dir_path: Root directory passed to the dataset constructor.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                sequence-dataset constructor.

        Returns:
            A ``SiamesePairwiseDataset`` wrapping the loaded sequences and a
            default-constructed ``TrackerConfig``.

        Raises:
            ValueError: If ``dataset_type`` matches none of the supported
                types.
        """
        # Factories are lazy so that only the selected dataset is built.
        sequence_factories = (
            (DatasetType.GOT10k,
             lambda: GOT10k(root_dir=dir_path, **kwargs)),
            (DatasetType.OTB13,
             lambda: OTB(root_dir=dir_path, version=2013, **kwargs)),
            (DatasetType.OTB15,
             lambda: OTB(root_dir=dir_path, version=2015, **kwargs)),
            (DatasetType.VOT15,
             lambda: VOT(dir_path, version=2015, **kwargs)),
            (DatasetType.ILSVRC15,
             lambda: ImageNetVID(root_dir=dir_path, subset='train',
                                 **kwargs)),
        )

        for supported_type, make_sequences in sequence_factories:
            if dataset_type == supported_type:
                sequences = make_sequences()
                break
        else:
            raise ValueError(f"unsupported dataset type: {dataset_type}")

        return SiamesePairwiseDataset(cast(Sequence, sequences),
                                      TrackerConfig())
Example #2
0
def main():
    """Train the SiamRPN tracker, saving a loss plot and model every epoch.

    Reads ``experiment_name``, ``train_path`` and ``checkpoint_path`` from the
    module-level argument parser, builds the train and validation loaders,
    optionally restores network weights from a checkpoint, and then runs
    ``config.epoches`` epochs of at most ``config.train_epoch_size`` steps.

    The process exits via ``sys.exit(0)`` as soon as the classification loss
    becomes NaN.
    """
    # Parameter initialization.
    args = parser.parse_args()
    exp_name_dir = AverageMeter.experiment_name_dir(args.experiment_name)

    # Build the tracker.
    model = TrackerSiamRPN()

    # Set up the train data loader.
    name = 'All'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset = GOT10k(root_dir, subset='train')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset='train')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset', len(seq_dataset))

    train_data = TrainDataLoader(seq_dataset, name)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=1,
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True)

    # Set up the val data loader.
    # NOTE(review): val_loader is built but not consumed by the loop below.
    name = 'All'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset_val = GOT10k(root_dir, subset='val')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val))

    val_data = TrainDataLoader(seq_dataset_val, name)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=1,
                            shuffle=True,
                            num_workers=16,
                            pin_memory=True)

    # Load weights: start from fresh initialization, then try the checkpoint.
    init_weights(model)

    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not valid checkpoint_path'.format(
                args.checkpoint_path)
        try:
            model.net.load_state_dict(
                torch.load(args.checkpoint_path,
                           map_location=lambda storage, loc: storage))
            print('You are loading the model.load_state_dict')
        except Exception as err:
            # Restoring stays best-effort, but the failure is reported, and
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print('Could not load checkpoint {} ({}); '
                  're-initializing weights'.format(args.checkpoint_path, err))
            init_weights(model)

    # Train phase.
    closses, rlosses, tlosses, steps = (AverageMeter(), AverageMeter(),
                                        AverageMeter(), AverageMeter())

    for epoch in range(config.epoches):
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        with tqdm(total=config.train_epoch_size) as progbar:
            for i, dataset in enumerate(train_loader):

                closs, rloss, loss, _cur_lr = model.step(epoch,
                                                         dataset,
                                                         backward=True)

                closs_value = closs.cpu().item()

                # NOTE(review): a NaN loss ends the run with exit status 0,
                # so a wrapper script sees success - confirm this is intended.
                if np.isnan(closs_value):
                    sys.exit(0)

                closses.update(closs_value)
                rlosses.update(rloss.cpu().item())
                tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(closses.avg),
                                    rloss='{:05.3f}'.format(rlosses.avg),
                                    tloss='{:05.3f}'.format(tlosses.avg))

                progbar.update()

                if i >= config.train_epoch_size - 1:
                    # Save plot.
                    closses.closs_array.append(closses.avg)
                    rlosses.rloss_array.append(rlosses.avg)
                    tlosses.loss_array.append(tlosses.avg)
                    steps.update(steps.count)
                    steps.steps_array.append(steps.count)

                    steps.plot(exp_name_dir)

                    # Save model.
                    model.save(model, exp_name_dir, epoch)

                    break
Example #3
0
 def test_vid(self):
     """Run the shared dataset check on ImageNetVID (train and val)."""
     vid_root = os.path.join(self.data_dir, 'ILSVRC')
     self._check_dataset(ImageNetVID(vid_root, subset=('train', 'val')))
Example #4
0
def main():
    """Train the SiamRPN tracker with a validation pass after every epoch.

    Reads ``experiment_name``, ``train_path`` and ``checkpoint_path`` from the
    module-level argument parser. Weights come from ``checkpoint_path`` when
    given, otherwise from ``config.pretrained_model`` when that is set. Each
    epoch runs at most ``config.train_epoch_size`` training steps, saves the
    model, runs at most ``config.val_epoch_size`` validation steps, and
    records the mean train/val losses in the ``train_val_plot`` plot.

    The process exits via ``sys.exit(0)`` as soon as a classification loss
    becomes NaN.
    """
    # Parameter initialization.
    args = parser.parse_args()
    exp_name_dir = util.experiment_name_dir(args.experiment_name)

    # Build the tracker and initialize its network weights.
    model = TrackerSiamRPN()
    model.net.init_weights()

    # Set up the train data loader.
    name = 'VID'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        # NOTE(review): the training branch loads the 'val' subset - confirm.
        seq_dataset = GOT10k(root_dir, subset='val')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID'
        seq_dataset = ImageNetVID(root_dir, subset=('train', 'val'))
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset', len(seq_dataset))

    train_data = TrainDataLoader(seq_dataset, name)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=64,
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True)

    # Set up the val data loader.
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset_val = GOT10k(root_dir, subset='val')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val))

    val_data = TrainDataLoader(seq_dataset_val, name)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=8,
                            shuffle=False,
                            num_workers=16,
                            pin_memory=True)

    # Load weights.
    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not valid checkpoint_path'.format(
                args.checkpoint_path)
        # Deserialize the file once; the weights are either stored under the
        # 'model' key or the checkpoint is itself the state dict.
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        model.net.load_state_dict(state_dict)
        print('You are loading the model.load_state_dict')

    elif config.pretrained_model:
        checkpoint = torch.load(config.pretrained_model)
        # Rename the pretrained keys to this network's naming, then overlay
        # them on the current state dict so unmatched entries are kept.
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.net.state_dict()
        model_dict.update(checkpoint)
        model.net.load_state_dict(model_dict)

    torch.cuda.empty_cache()

    # Train phase.
    train_closses, train_rlosses, train_tlosses = (AverageMeter(),
                                                   AverageMeter(),
                                                   AverageMeter())
    val_closses, val_rlosses, val_tlosses = (AverageMeter(), AverageMeter(),
                                             AverageMeter())

    train_val_plot = SavePlot(exp_name_dir, 'train_val_plot')

    for epoch in range(config.epoches):
        model.net.train()
        if config.fix_former_3_layers:
            # The former `if 1 > 1` branch (model.net.module) could never
            # run, so only the plain-network call remains.
            util.freeze_layers(model.net)
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        train_loss = []
        with tqdm(total=config.train_epoch_size) as progbar:
            for i, dataset in enumerate(train_loader):

                closs, rloss, loss = model.step(epoch, dataset, train=True)

                closs_value = closs.cpu().item()

                # NOTE(review): a NaN loss ends the run with exit status 0,
                # so a wrapper script sees success - confirm this is intended.
                if np.isnan(closs_value):
                    sys.exit(0)

                train_closses.update(closs_value)
                train_rlosses.update(rloss.cpu().item())
                train_tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(train_closses.avg),
                                    rloss='{:05.3f}'.format(train_rlosses.avg),
                                    tloss='{:05.3f}'.format(train_tlosses.avg))

                progbar.update()
                train_loss.append(train_tlosses.avg)

                if i >= config.train_epoch_size - 1:
                    # Save model.
                    model.save(model, exp_name_dir, epoch)

                    break

        train_loss = np.mean(train_loss)

        # Val phase.
        # NOTE(review): model.net is still in train() mode here unless
        # model.step(train=False) switches it - confirm.
        val_loss = []
        with tqdm(total=config.val_epoch_size) as progbar:
            print('Val epoch {}/{}'.format(epoch + 1, config.epoches))
            for i, dataset in enumerate(val_loader):

                val_closs, val_rloss, val_tloss = model.step(epoch,
                                                             dataset,
                                                             train=False)

                val_closs_value = val_closs.cpu().item()

                if np.isnan(val_closs_value):
                    sys.exit(0)

                val_closses.update(val_closs_value)
                val_rlosses.update(val_rloss.cpu().item())
                val_tlosses.update(val_tloss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(val_closses.avg),
                                    rloss='{:05.3f}'.format(val_rlosses.avg),
                                    tloss='{:05.3f}'.format(val_tlosses.avg))

                progbar.update()

                val_loss.append(val_tlosses.avg)

                if i >= config.val_epoch_size - 1:
                    break

        val_loss = np.mean(val_loss)
        train_val_plot.update(train_loss, val_loss)
        print('Train loss: {}, val loss: {}'.format(train_loss, val_loss))
Example #5
0
def main():
    """Train the SiamRPN tracker from two parallel loaders (RGB and IR).

    Reads ``experiment_name``, ``train_path`` and ``checkpoint_path`` from the
    module-level argument parser. Two datasets are built from the same root,
    one with the default arguments and one with ``visible=False``, and their
    loaders are zipped so every step receives one batch from each. Weights
    come from ``checkpoint_path`` when given, otherwise from
    ``config.pretrained_model`` when that is set.

    After each epoch the mean validation loss is plotted and the list of all
    validation losses so far is rewritten to ``val_losses.txt`` inside the
    experiment directory.

    The process exits via ``sys.exit(0)`` as soon as a classification loss
    becomes NaN.
    """
    # Parameter initialization.
    args = parser.parse_args()
    exp_name_dir = util.experiment_name_dir(args.experiment_name)

    # Build the tracker.
    model = TrackerSiamRPN()

    # Set up the train data loaders.
    # NOTE(review): only the 'GOT-10k' branch defines seq_dataset_rgb and
    # seq_dataset_i, which everything below relies on.
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All', 'RGBT-234']
    if name == 'GOT-10k':
        root_dir_RGBT234 = args.train_path
        seq_dataset_rgb = GOT10k(root_dir_RGBT234, subset='train_i')
        seq_dataset_i = GOT10k(root_dir_RGBT234,
                               subset='train_i',
                               visible=False)
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset='train')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    elif name == 'RGBT-234':
        root_dir = args.train_path
        seq_dataset = RGBTSequence(root_dir, subset='train')
        seq_dataset_val = RGBTSequence(root_dir, subset='val')
    print('seq_dataset', len(seq_dataset_rgb))

    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([ToTensor()])

    train_data_ir = TrainDataLoader_ir(seq_dataset_i, train_z_transforms,
                                       train_x_transforms, name)
    anchors = train_data_ir.anchors
    train_loader_ir = DataLoader(dataset=train_data_ir,
                                 batch_size=config.train_batch_size,
                                 shuffle=True,
                                 num_workers=config.train_num_workers,
                                 pin_memory=True)
    train_data_rgb = TrainDataLoader(seq_dataset_rgb, train_z_transforms,
                                     train_x_transforms, name)
    # The RGB dataset's anchors replace the IR ones and are the ones passed
    # to model.step below.
    anchors = train_data_rgb.anchors
    train_loader_rgb = DataLoader(dataset=train_data_rgb,
                                  batch_size=config.train_batch_size,
                                  shuffle=True,
                                  num_workers=config.train_num_workers,
                                  pin_memory=True)

    # Set up the val data loaders.
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All', 'RGBT-234']
    if name == 'GOT-10k':
        val_dir = '/home/krautsct/RGB-t-Val'
        seq_dataset_val_rgb = GOT10k(val_dir, subset='train_i')
        seq_dataset_val_ir = GOT10k(val_dir, subset='train_i', visible=False)
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val_rgb))

    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])

    val_data = TrainDataLoader_ir(seq_dataset_val_ir, valid_z_transforms,
                                  valid_x_transforms, name)
    val_loader_ir = DataLoader(dataset=val_data,
                               batch_size=config.valid_batch_size,
                               shuffle=False,
                               num_workers=config.valid_num_workers,
                               pin_memory=True)
    val_data_rgb = TrainDataLoader(seq_dataset_val_rgb, valid_z_transforms,
                                   valid_x_transforms, name)
    val_loader_rgb = DataLoader(dataset=val_data_rgb,
                                batch_size=config.valid_batch_size,
                                shuffle=False,
                                num_workers=config.valid_num_workers,
                                pin_memory=True)

    val_losslist = []

    # Load weights.
    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not valid checkpoint_path'.format(
                args.checkpoint_path)
        # Deserialize the file once; the weights are either stored under the
        # 'model' key or the checkpoint is itself the state dict.
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        model.net.load_state_dict(state_dict)
        torch.cuda.empty_cache()
        print('You are loading the model.load_state_dict')

    elif config.pretrained_model:
        checkpoint = torch.load(config.pretrained_model)
        # Rename the pretrained keys to this network's naming, then overlay
        # them on the current state dict so unmatched entries are kept.
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.net.state_dict()
        model_dict.update(checkpoint)
        model.net.load_state_dict(model_dict)

    # Train phase.
    train_closses, train_rlosses, train_tlosses = (AverageMeter(),
                                                   AverageMeter(),
                                                   AverageMeter())
    val_closses, val_rlosses, val_tlosses = (AverageMeter(), AverageMeter(),
                                             AverageMeter())

    val_plot = SavePlotVal(exp_name_dir, 'val_plot')
    for epoch in range(config.epoches):
        model.net.train()
        if config.fix_former_3_layers:
            util.freeze_layers(model.net)
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        train_loss = []
        with tqdm(total=config.train_epoch_size) as progbar:
            for i, (dataset_rgb, dataset_ir) in enumerate(
                    zip(train_loader_rgb, train_loader_ir)):

                closs, rloss, loss = model.step(epoch,
                                                dataset_rgb,
                                                dataset_ir,
                                                anchors,
                                                epoch,
                                                i,
                                                train=True)

                closs_value = closs.cpu().item()

                # NOTE(review): a NaN loss ends the run with exit status 0,
                # so a wrapper script sees success - confirm this is intended.
                if np.isnan(closs_value):
                    sys.exit(0)

                train_closses.update(closs_value)
                train_rlosses.update(rloss.cpu().item())
                train_tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(train_closses.avg),
                                    rloss='{:05.5f}'.format(train_rlosses.avg),
                                    tloss='{:05.3f}'.format(train_tlosses.avg))

                progbar.update()
                train_loss.append(train_tlosses.avg)

                if i >= config.train_epoch_size - 1:
                    # Save model.
                    model.save(model, exp_name_dir, epoch)

                    break

        train_loss = np.mean(train_loss)

        # Val phase.
        val_loss = []
        with tqdm(total=config.val_epoch_size) as progbar:
            print('Val epoch {}/{}'.format(epoch + 1, config.epoches))
            for i, (dataset_rgb,
                    dataset_ir) in enumerate(zip(val_loader_rgb,
                                                 val_loader_ir)):

                # NOTE(review): unlike the training call, the step index `i`
                # is not passed here - confirm against model.step's signature.
                val_closs, val_rloss, val_tloss = model.step(epoch,
                                                             dataset_rgb,
                                                             dataset_ir,
                                                             anchors,
                                                             epoch,
                                                             train=False)

                val_closs_value = val_closs.cpu().item()

                if np.isnan(val_closs_value):
                    sys.exit(0)

                val_closses.update(val_closs_value)
                val_rlosses.update(val_rloss.cpu().item())
                val_tlosses.update(val_tloss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(val_closses.avg),
                                    rloss='{:05.5f}'.format(val_rlosses.avg),
                                    tloss='{:05.3f}'.format(val_tlosses.avg))

                progbar.update()

                val_loss.append(val_tlosses.avg)

                if i >= config.val_epoch_size - 1:
                    break

        val_loss = np.mean(val_loss)
        val_plot.update(val_loss)
        val_losslist.append(val_loss)
        print('Train loss: {}, val loss: {}'.format(train_loss, val_loss))

        # The record file lives inside exp_name_dir, so that is the directory
        # that must exist (its parent alone is not enough).
        os.makedirs(exp_name_dir, exist_ok=True)
        record_file = os.path.join(exp_name_dir, 'val_losses.txt')
        np.savetxt(record_file, val_losslist, fmt='%.3f', delimiter=',')
Example #6
0
from torchvision import models
from got10k.datasets import ImageNetVID, GOT10k
from pairwise_cf import Pairwise
from siamfc_VGG_cf import TrackerSiamFC
from got10k_tmp.experiments import *

if __name__ == '__main__':
    # Choose the sequence dataset that the pairwise dataset is built from.
    name = 'VID'
    assert name in ['VID', 'GOT-10k']
    if name == 'VID':
        root_dir = '/home/user/ILSVRC2015'
        seq_dataset = ImageNetVID(root_dir, subset=('train', 'val'))
    elif name == 'GOT-10k':
        root_dir = 'data/GOT-10k'
        seq_dataset = GOT10k(root_dir, subset='train')
    pair_dataset = Pairwise(seq_dataset)

    # Data loader; pin_memory follows CUDA availability.
    cuda = torch.cuda.is_available()
    loader = DataLoader(pair_dataset,
                        num_workers=4,
                        drop_last=True,
                        pin_memory=cuda,
                        shuffle=True,
                        batch_size=8)

    # Tracker under training.
    tracker = TrackerSiamFC()

    #pretrained vgg
Example #7
0
#from pairwise import Pairwise
from siamfc import TrackerSiamFC
from got10k.experiments import *

from config import config

if __name__ == '__main__':

    # Select the training data: a single sequence dataset, or the
    # concatenation of the pairwise datasets built from both.
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'All':
        seq_got_dataset = GOT10k(config.root_dir_for_GOT_10k, subset='train')
        seq_vid_dataset = ImageNetVID(config.root_dir_for_VID,
                                      subset=('train', 'val'))
        pair_dataset = Pairwise(seq_got_dataset) + Pairwise(seq_vid_dataset)
    else:
        if name == 'GOT-10k':
            seq_dataset = GOT10k(config.root_dir_for_GOT_10k, subset='train')
        elif name == 'VID':
            seq_dataset = ImageNetVID(config.root_dir_for_VID,
                                      subset=('train', 'val'))
        pair_dataset = Pairwise(seq_dataset)

    print(len(pair_dataset))

    # setup data loader
    cuda = torch.cuda.is_available()
    loader = DataLoader(pair_dataset,
                        batch_size = config.batch_size,
                        shuffle    = True,
                        pin_memory = cuda,
                        drop_last  = True,
Example #8
0
from siamfc import TrackerSiamFC
from got10k.experiments import *
import numpy as np

from config import config

if __name__ == '__main__':

    # Pick the training data: one sequence dataset, or the concatenation of
    # the pairwise datasets built from both.
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'All':
        seq_got_dataset = GOT10k(config.GOT_10k_dataset_directory,
                                 subset='train')
        seq_vid_dataset = ImageNetVID(config.Imagenet_dataset_directory,
                                      subset=('train', 'val'))
        pair_dataset = Pairwise(seq_got_dataset) + Pairwise(seq_vid_dataset)
    else:
        if name == 'GOT-10k':
            seq_dataset = GOT10k(config.GOT_10k_dataset_directory,
                                 subset='train')
        elif name == 'VID':
            seq_dataset = ImageNetVID(config.Imagenet_dataset_directory,
                                      subset=('train', 'val'))
        pair_dataset = Pairwise(seq_dataset)

    print(len(pair_dataset))

    # setup the data loader
    cuda = torch.cuda.is_available()
    loader = DataLoader(pair_dataset,
                        batch_size=config.batch_size,
                        shuffle=True,
Example #9
0
def main():
    """Train the SiamRPN tracker, optionally resuming from a saved epoch.

    Reads ``experiment_name``, ``train_path``, ``checkpoint_path`` and
    ``epoch_i`` from the module-level argument parser. When a checkpoint
    directory is given and ``epoch_i > 0``, weights are restored from
    ``model_e{epoch_i}.pth`` in that directory; otherwise
    ``config.pretrained_model`` is used when set. Training then runs from
    ``epoch_i`` up to ``config.epoches``, with a validation pass and a
    train/val plot update after every epoch.

    The process exits via ``sys.exit(0)`` as soon as a classification loss
    becomes NaN.
    """
    # Parameter initialization.
    args = parser.parse_args()
    exp_name_dir = util.experiment_name_dir(args.experiment_name)

    # Build the tracker.
    model = TrackerSiamRPN()

    # Set up the train data loader.
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset = GOT10k(root_dir, subset='train')
    elif name == 'VID':
        root_dir = '/store_ssd/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset='train')
    elif name == 'All':
        root_dir_vid = '/store_ssd/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset', len(seq_dataset))

    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([
        RandomCrop([config.detection_img_size, config.detection_img_size],
                   config.max_translate),
        RandomScale(config.scale_resize),
        ToTensor()
    ])

    train_data = TrainDataLoader(seq_dataset, train_z_transforms,
                                 train_x_transforms, name)
    anchors = train_data.anchors
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              num_workers=config.train_num_workers,
                              pin_memory=True)

    # Set up the val data loader.
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset_val = GOT10k(root_dir, subset='val')
    elif name == 'VID':
        root_dir = '/store_ssd/ILSVRC'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/store_ssd/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val))

    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])

    val_data = TrainDataLoader(seq_dataset_val, valid_z_transforms,
                               valid_x_transforms, name)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=config.valid_batch_size,
                            shuffle=False,
                            num_workers=config.valid_num_workers,
                            pin_memory=True)

    # Load weights.
    if args.checkpoint_path is not None and args.epoch_i > 0:
        checkpoint_path = os.path.join(args.checkpoint_path,
                                       'model_e{}.pth'.format(args.epoch_i))
        assert os.path.isfile(
            checkpoint_path), '{} is not valid checkpoint_path'.format(
                checkpoint_path)

        # Deserialize the file once; the weights are either stored under the
        # 'model' key or the checkpoint is itself the state dict.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        model.net.load_state_dict(state_dict)
        torch.cuda.empty_cache()
        print('You are loading the model.load_state_dict')

    elif config.pretrained_model:
        checkpoint = torch.load(config.pretrained_model)
        # Rename the pretrained keys to this network's naming, then overlay
        # them on the current state dict so unmatched entries are kept.
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.net.state_dict()
        model_dict.update(checkpoint)
        model.net.load_state_dict(model_dict)
        print('You are loading the pretrained model')

    # Train phase.
    train_closses, train_rlosses, train_tlosses = (AverageMeter(),
                                                   AverageMeter(),
                                                   AverageMeter())
    val_closses, val_rlosses, val_tlosses = (AverageMeter(), AverageMeter(),
                                             AverageMeter())

    train_val_plot = SavePlot(exp_name_dir, 'train_val_plot')
    model.adjust_lr(args.epoch_i)

    for epoch in range(args.epoch_i, config.epoches):
        model.net.train()
        if config.fix_former_3_layers:
            util.freeze_layers(model.net)
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        train_loss = []
        with tqdm(total=config.train_epoch_size) as progbar:
            for i, dataset in enumerate(train_loader):

                closs, rloss, loss = model.step(epoch,
                                                dataset,
                                                anchors,
                                                i,
                                                train=True)

                closs_value = closs.cpu().item()

                # NOTE(review): a NaN loss ends the run with exit status 0,
                # so a wrapper script sees success - confirm this is intended.
                if np.isnan(closs_value):
                    sys.exit(0)

                train_closses.update(closs_value)
                train_rlosses.update(rloss.cpu().item())
                train_tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(train_closses.avg),
                                    rloss='{:05.5f}'.format(train_rlosses.avg),
                                    tloss='{:05.3f}'.format(train_tlosses.avg))

                progbar.update()
                train_loss.append(train_tlosses.avg)

                if i >= config.train_epoch_size - 1:
                    # Save model.
                    model.save(model, exp_name_dir, epoch)

                    break

        train_loss = np.mean(train_loss)

        # Val phase.
        val_loss = []
        with tqdm(total=config.val_epoch_size) as progbar:
            print('Val epoch {}/{}'.format(epoch + 1, config.epoches))
            for i, dataset in enumerate(val_loader):

                # NOTE(review): unlike the training call, the step index `i`
                # is not passed here - confirm against model.step's signature.
                val_closs, val_rloss, val_tloss = model.step(epoch,
                                                             dataset,
                                                             anchors,
                                                             train=False)

                val_closs_value = val_closs.cpu().item()

                if np.isnan(val_closs_value):
                    sys.exit(0)

                val_closses.update(val_closs_value)
                val_rlosses.update(val_rloss.cpu().item())
                val_tlosses.update(val_tloss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(val_closses.avg),
                                    rloss='{:05.5f}'.format(val_rlosses.avg),
                                    tloss='{:05.3f}'.format(val_tlosses.avg))

                progbar.update()

                val_loss.append(val_tlosses.avg)

                if i >= config.val_epoch_size - 1:
                    break

        val_loss = np.mean(val_loss)
        train_val_plot.update(train_loss, val_loss)
        print('Train loss: {}, val loss: {}'.format(train_loss, val_loss))