Example #1
def main():
    '''parameter initialization'''
    args = parser.parse_args()
    exp_name_dir = util.experiment_name_dir(args.experiment_name)
    '''model on gpu'''
    model = TrackerSiamRPN()
    '''setup train data loader'''
    name = 'All'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset = GOT10k(root_dir, subset='train')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset='train')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset', len(seq_dataset))

    train_data = TrainDataLoader(seq_dataset, name)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=1,
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True)
    '''setup val data loader'''
    name = 'All'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset_val = GOT10k(root_dir, subset='val')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val))

    val_data = TrainDataLoader(seq_dataset_val, name)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=1,
                            shuffle=True,
                            num_workers=16,
                            pin_memory=True)
    '''load weights'''
    init_weights(model)

    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not valid checkpoint_path'.format(
                args.checkpoint_path)
        try:
            model.net.load_state_dict(
                torch.load(args.checkpoint_path,
                           map_location=lambda storage, loc: storage))
            print('Loaded weights from {}'.format(args.checkpoint_path))
        except Exception:
            # fall back to random initialization if the checkpoint is incompatible
            init_weights(model)
    '''train phase'''
    closses, rlosses, tlosses, steps = (AverageMeter(), AverageMeter(),
                                        AverageMeter(), AverageMeter())

    for epoch in range(config.epoches):
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        with tqdm(total=config.train_epoch_size) as progbar:
            for i, dataset in enumerate(train_loader):

                closs, rloss, loss, cur_lr = model.step(epoch,
                                                        dataset,
                                                        backward=True)

                closs_ = closs.cpu().item()

                if np.isnan(closs_):
                    sys.exit('classification loss is NaN, aborting')

                closses.update(closs.cpu().item())
                rlosses.update(rloss.cpu().item())
                tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(closses.avg),
                                    rloss='{:05.3f}'.format(rlosses.avg),
                                    tloss='{:05.3f}'.format(tlosses.avg))

                progbar.update()

                if i >= config.train_epoch_size - 1:
                    '''save plot'''
                    closses.closs_array.append(closses.avg)
                    rlosses.rloss_array.append(rlosses.avg)
                    tlosses.loss_array.append(tlosses.avg)
                    steps.update(steps.count)
                    steps.steps_array.append(steps.count)

                    steps.plot(exp_name_dir)
                    '''save model'''
                    model.save(model, exp_name_dir, epoch)

                    break
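
The training examples all rely on an AverageMeter helper that is defined elsewhere in the repo. For reference, here is a minimal sketch of the conventional running-average meter implied by the update()/avg/count usage above; the per-epoch arrays (closs_array, steps_array, ...) and the plot() method that Example #1 also touches are repo-specific bookkeeping and are only noted in the docstring.

class AverageMeter:
    """Running average of a scalar (sketch; the repo's version also keeps
    per-epoch arrays such as closs_array and a plot() helper)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # last value seen
        self.sum = 0.0   # running sum
        self.count = 0   # number of updates
        self.avg = 0.0   # running mean

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
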
Example #2
def main():
    '''parameter initialization'''
    args = parser.parse_args()
    exp_name_dir = util.experiment_name_dir(args.experiment_name)
    '''model on gpu'''
    model = TrackerSiamRPN()
    model.net.init_weights()
    '''setup train data loader'''
    name = 'VID'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset = GOT10k(root_dir, subset='train')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID'
        seq_dataset = ImageNetVID(root_dir, subset=('train', 'val'))
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset', len(seq_dataset))

    train_data = TrainDataLoader(seq_dataset, name)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=64,
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True)
    '''setup val data loader'''
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset_val = GOT10k(root_dir, subset='val')
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC2017_VID'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC2017_VID/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val))

    val_data = TrainDataLoader(seq_dataset_val, name)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=8,
                            shuffle=False,
                            num_workers=16,
                            pin_memory=True)
    '''load weights'''

    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not valid checkpoint_path'.format(
                args.checkpoint_path)
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        # a checkpoint is either a bare state_dict or wraps it under 'model'
        if 'model' in checkpoint.keys():
            model.net.load_state_dict(checkpoint['model'])
        else:
            model.net.load_state_dict(checkpoint)
        print('Loaded weights from {}'.format(args.checkpoint_path))

    elif config.pretrained_model:
        #print("init with pretrained checkpoint %s" % config.pretrained_model + '\n')
        #print('------------------------------------------------------------------------------------------------ \n')
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.net.state_dict()
        model_dict.update(checkpoint)
        model.net.load_state_dict(model_dict)

    torch.cuda.empty_cache()
    '''train phase'''
    train_closses, train_rlosses, train_tlosses = (AverageMeter(),
                                                   AverageMeter(),
                                                   AverageMeter())
    val_closses, val_rlosses, val_tlosses = (AverageMeter(), AverageMeter(),
                                             AverageMeter())

    train_val_plot = SavePlot(exp_name_dir, 'train_val_plot')

    for epoch in range(config.epoches):
        model.net.train()
        if config.fix_former_3_layers:
            # re-freeze the early layers every epoch, since train() above
            # re-enables BatchNorm updates; unwrap DataParallel if wrapped
            if isinstance(model.net, torch.nn.DataParallel):
                util.freeze_layers(model.net.module)
            else:
                util.freeze_layers(model.net)
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        train_loss = []
        with tqdm(total=config.train_epoch_size) as progbar:
            for i, dataset in enumerate(train_loader):

                closs, rloss, loss = model.step(epoch, dataset, train=True)

                closs_ = closs.cpu().item()

                if np.isnan(closs_):
                    sys.exit('classification loss is NaN, aborting')

                train_closses.update(closs.cpu().item())
                train_rlosses.update(rloss.cpu().item())
                train_tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(train_closses.avg),
                                    rloss='{:05.3f}'.format(train_rlosses.avg),
                                    tloss='{:05.3f}'.format(train_tlosses.avg))

                progbar.update()
                train_loss.append(train_tlosses.avg)

                if i >= config.train_epoch_size - 1:
                    '''save plot'''
                    #train_val_plot.update(train_tlosses.avg, train_label = 'total loss')
                    '''save model'''
                    model.save(model, exp_name_dir, epoch)

                    break

        train_loss = np.mean(train_loss)
        '''val phase'''
        val_loss = []
        with tqdm(total=config.val_epoch_size) as progbar:
            print('Val epoch {}/{}'.format(epoch + 1, config.epoches))
            for i, dataset in enumerate(val_loader):

                val_closs, val_rloss, val_tloss = model.step(epoch,
                                                             dataset,
                                                             train=False)

                closs_ = val_closs.cpu().item()

                if np.isnan(closs_):
                    sys.exit('validation classification loss is NaN, aborting')

                val_closses.update(val_closs.cpu().item())
                val_rlosses.update(val_rloss.cpu().item())
                val_tlosses.update(val_tloss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(val_closses.avg),
                                    rloss='{:05.3f}'.format(val_rlosses.avg),
                                    tloss='{:05.3f}'.format(val_tlosses.avg))

                progbar.update()

                val_loss.append(val_tlosses.avg)

                if i >= config.val_epoch_size - 1:
                    break

        val_loss = np.mean(val_loss)
        train_val_plot.update(train_loss, val_loss)
        print('Train loss: {}, val loss: {}'.format(train_loss, val_loss))
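
The config.pretrained_model branch above seeds the tracker with AlexNet-style backbone weights by renaming checkpoint keys before load_state_dict. Here is a standalone sketch of that pattern with an added shape check the snippet omits; load_backbone is a hypothetical helper name, and the key prefixes are the ones used in the replace() call above.

import torch

def load_backbone(net, pretrained_path):
    # remap 'features.features.*' keys from the pretrained file to the
    # 'featureExtract.*' names used by the tracker's feature extractor
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    checkpoint = {k.replace('features.features', 'featureExtract'): v
                  for k, v in checkpoint.items()}
    model_dict = net.state_dict()
    # keep only keys the target net actually has, with matching shapes
    compatible = {k: v for k, v in checkpoint.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(compatible)
    net.load_state_dict(model_dict)
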
Example #3
def main():
    """parameter initialization"""
    args = parser.parse_args()
    exp_name_dir = experiment_name_dir(args.experiment_name)
    """Load the parameters from json file"""
    json_path = os.path.join(exp_name_dir, 'parameters.json')
    assert os.path.isfile(json_path), (
        "No json configuration file found at {}".format(json_path))
    with open(json_path) as data_file:
        params = json.load(data_file)
    """ train dataloader """
    data_loader = TrainDataLoader(args.train_path)
    """ compute max_batches """
    for root, dirs, files in os.walk(args.train_path):
        for dirname in dirs:
            dir_path = os.path.join(root, dirname)
            args.max_batches += len(os.listdir(dir_path))
    """ Model on gpu """
    model = TrackerSiamRPN(params)
    #model = model.cuda()
    cudnn.benchmark = True
    """ load weights """
    init_weights(model)
    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not valid checkpoint_path'.format(
                args.checkpoint_path)
        try:
            checkpoint = torch.load(args.checkpoint_path)
            start = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        except Exception:
            # incompatible or weights-only checkpoint: start from scratch
            start = 0
            init_weights(model)
    else:
        start = 0
    """ train phase """
    closses, rlosses, tlosses = AverageMeter(), AverageMeter(), AverageMeter()
    steps = 0
    for epoch in range(start, args.max_epoches):
        #cur_lr = adjust_learning_rate(params["lr"], optimizer, epoch, gamma=0.1)
        index_list = range(len(data_loader))
        for example in tqdm(range(1000)):  # args.max_batches
            ret = data_loader.__get__(random.choice(index_list))

            closs, rloss, loss, reg_pred, reg_target, pos_index, neg_index, cur_lr = model.step(
                ret, epoch, backward=True)

            closs_ = closs.cpu().item()

            if np.isnan(closs_):
                sys.exit('classification loss is NaN, aborting')

            closses.update(closs.cpu().item())
            rlosses.update(rloss.cpu().item())
            tlosses.update(loss.cpu().item())
            steps += 1

            if example % 1000 == 0:
                print(
                    "Epoch:{:04d}\texample:{:06d}/{:06d}({:.2f})%\tlr:{:.7f}\tcloss:{:.4f}\trloss:{:.4f}\ttloss:{:.4f}"
                    .format((epoch + 1), steps, args.max_batches,
                            100 * (steps) / args.max_batches, cur_lr,
                            closses.avg, rlosses.avg, tlosses.avg))
        """save model"""
        model_save_dir_pth = '{}/model'.format(exp_name_dir)
        if not os.path.exists(model_save_dir_pth):
            os.makedirs(model_save_dir_pth)
        net_path = os.path.join(model_save_dir_pth,
                                'model_e%d.pth' % (epoch + 1))
        torch.save(model.net.state_dict(), net_path)
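
Note that the resume logic in Example #3 reads 'epoch', 'state_dict', and 'optimizer' keys from the checkpoint, while its save step stores only model.net.state_dict(); resuming from its own files would therefore always fall into the except branch. A sketch of a save side matching the resume format follows; save_checkpoint is a hypothetical helper, and it assumes the optimizer object is in scope, which the snippet above does not show.

import os
import torch

def save_checkpoint(model, optimizer, epoch, exp_name_dir):
    # store everything the resume code in Example #3 reads back
    model_dir = os.path.join(exp_name_dir, 'model')
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    torch.save({
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, os.path.join(model_dir, 'model_e%d.pth' % (epoch + 1)))
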
Example #4
def main():
    '''parameter initialization'''
    args = parser.parse_args()
    exp_name_dir = util.experiment_name_dir(args.experiment_name)
    '''model on gpu'''
    model = TrackerSiamRPN()
    '''setup train data loader'''
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All', 'RGBT-234']
    if name == 'GOT-10k':
        root_dir_RGBT234 = args.train_path
        root_dir_GTOT = '/home/krautsct/Grayscale-Thermal-Dataset'
        seq_dataset_rgb = GOT10k(root_dir_RGBT234, subset='train_i')
        #seq_dataset_i = GOT10k(root_dir_RGBT234, root_dir_GTOT, subset='train', visible=False)
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset='train')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    elif name == 'RGBT-234':
        root_dir = args.train_path
        seq_dataset = RGBTSequence(root_dir, subset='train')
        seq_dataset_val = RGBTSequence(root_dir, subset='val')
    print('seq_dataset', len(seq_dataset_rgb))

    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([ToTensor()])
    '''train_data_ir  = TrainDataLoader_ir(seq_dataset_i, train_z_transforms, train_x_transforms, name)
    anchors = train_data_ir.anchors
    train_loader_ir = DataLoader(  dataset    = train_data_ir,
                                batch_size = config.train_batch_size,
                                shuffle    = True,
                                num_workers= config.train_num_workers,
                                pin_memory = True)'''
    train_data_rgb = TrainDataLoader(seq_dataset_rgb, train_z_transforms,
                                     train_x_transforms, name)
    anchors = train_data_rgb.anchors
    train_loader_rgb = DataLoader(dataset=train_data_rgb,
                                  batch_size=config.train_batch_size,
                                  shuffle=True,
                                  num_workers=config.train_num_workers,
                                  pin_memory=True)
    '''setup val data loader'''
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All', 'RGBT-234']
    if name == 'GOT-10k':
        val_dir = '/home/krautsct/RGB-t-Val'
        seq_dataset_val_rgb = GOT10k(val_dir, subset='train_i')
        #seq_dataset_val_ir = GOT10k(val_dir, subset='train_i', visible=False)
    elif name == 'VID':
        root_dir = '/home/arbi/desktop/ILSVRC'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/home/arbi/desktop/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val_rgb))

    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])
    '''val_data  = TrainDataLoader_ir(seq_dataset_val_ir, valid_z_transforms, valid_x_transforms, name)
    val_loader_ir = DataLoader(    dataset    = val_data,
                                batch_size = config.valid_batch_size,
                                shuffle    = False,
                                num_workers= config.valid_num_workers,
                                pin_memory = True)'''
    val_data_rgb = TrainDataLoader(seq_dataset_val_rgb, valid_z_transforms,
                                   valid_x_transforms, name)
    val_loader_rgb = DataLoader(dataset=val_data_rgb,
                                batch_size=config.valid_batch_size,
                                shuffle=False,
                                num_workers=config.valid_num_workers,
                                pin_memory=True)

    val_losslist = []
    '''load weights'''

    if args.checkpoint_path is not None:
        assert os.path.isfile(
            args.checkpoint_path), '{} is not valid checkpoint_path'.format(
                args.checkpoint_path)
        checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
        # a checkpoint is either a bare state_dict or wraps it under 'model'
        if 'model' in checkpoint.keys():
            model.net.load_state_dict(checkpoint['model'])
        else:
            model.net.load_state_dict(checkpoint)
        torch.cuda.empty_cache()
        print('Loaded weights from {}'.format(args.checkpoint_path))

    elif config.pretrained_model:
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.net.state_dict()
        model_dict.update(checkpoint)
        model.net.load_state_dict(model_dict)
        #torch.cuda.empty_cache()
    '''train phase'''
    train_closses, train_rlosses, train_tlosses = (AverageMeter(),
                                                   AverageMeter(),
                                                   AverageMeter())
    val_closses, val_rlosses, val_tlosses = (AverageMeter(), AverageMeter(),
                                             AverageMeter())

    #train_val_plot = SavePlot(exp_name_dir, 'train_val_plot')
    val_plot = SavePlotVal(exp_name_dir, 'val_plot')
    for epoch in range(config.epoches):
        model.net.train()
        if config.fix_former_3_layers:
            util.freeze_layers(model.net)
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        train_loss = []
        with tqdm(total=config.train_epoch_size) as progbar:
            #for i, (dataset_rgb, dataset_ir) in enumerate(zip(train_loader_rgb, train_loader_ir)):
            for i, dataset_rgb in enumerate(train_loader_rgb):

                closs, rloss, loss = model.step(epoch,
                                                dataset_rgb,
                                                anchors,
                                                epoch,
                                                i,
                                                train=True)  # dataset_ir,

                closs_ = closs.cpu().item()

                if np.isnan(closs_):
                    sys.exit('classification loss is NaN, aborting')

                train_closses.update(closs.cpu().item())
                train_rlosses.update(rloss.cpu().item())
                train_tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(train_closses.avg),
                                    rloss='{:05.5f}'.format(train_rlosses.avg),
                                    tloss='{:05.3f}'.format(train_tlosses.avg))

                progbar.update()
                train_loss.append(train_tlosses.avg)

                if i >= config.train_epoch_size - 1:
                    '''save model'''
                    model.save(model, exp_name_dir, epoch)

                    break

        train_loss = np.mean(train_loss)
        '''val phase'''
        val_loss = []
        with tqdm(total=config.val_epoch_size) as progbar:
            print('Val epoch {}/{}'.format(epoch + 1, config.epoches))
            #for i, (dataset_rgb, dataset_ir) in enumerate(zip(val_loader_rgb, val_loader_ir)):
            for i, dataset_rgb in enumerate(val_loader_rgb):

                val_closs, val_rloss, val_tloss = model.step(
                    epoch, dataset_rgb, anchors, epoch,
                    train=False)  # dataset_ir,

                closs_ = val_closs.cpu().item()

                if np.isnan(closs_):
                    sys.exit('validation classification loss is NaN, aborting')

                val_closses.update(val_closs.cpu().item())
                val_rlosses.update(val_rloss.cpu().item())
                val_tlosses.update(val_tloss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(val_closses.avg),
                                    rloss='{:05.5f}'.format(val_rlosses.avg),
                                    tloss='{:05.3f}'.format(val_tlosses.avg))

                progbar.update()

                val_loss.append(val_tlosses.avg)

                if i >= config.val_epoch_size - 1:
                    break

        val_loss = np.mean(val_loss)
        #train_val_plot.update(train_loss, val_loss)
        val_plot.update(val_loss)
        val_losslist.append(val_loss)
        print('Train loss: {}, val loss: {}'.format(train_loss, val_loss))
        if not os.path.isdir(exp_name_dir):
            os.makedirs(exp_name_dir)
        record_file = os.path.join(exp_name_dir, 'val_losses.txt')
        np.savetxt(record_file, val_losslist, fmt='%.3f', delimiter=',')
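
Examples #2, #4, and #6 call util.freeze_layers at the top of every epoch when config.fix_former_3_layers is set; repeating the call each epoch matters because model.net.train() puts previously frozen BatchNorm layers back into training mode. The helper itself is not shown in these listings, so the following is only a sketch, assuming the backbone is an AlexNet-style nn.Sequential named featureExtract (as the pretrained-weight remapping suggests).

import torch.nn as nn

def freeze_layers(net, num_conv=3):
    # disable gradients (and fix BatchNorm statistics) for the first
    # num_conv convolutional blocks of net.featureExtract
    seen_conv = 0
    for layer in net.featureExtract:
        if isinstance(layer, nn.Conv2d):
            seen_conv += 1
            if seen_conv > num_conv:
                break
        for p in layer.parameters():
            p.requires_grad = False
        if isinstance(layer, nn.BatchNorm2d):
            layer.eval()  # keep running mean/var fixed
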
Example #5
from got10k.experiments import *

from net import TrackerSiamRPN


if __name__ == '__main__':

    '''setup tracker'''
    net_path = '../train/experiments/default/model/model_e6.pth'
    tracker = TrackerSiamRPN(net_path=net_path)

    '''setup experiments'''
    experiments = ExperimentGOT10k('/Users/arbi/Desktop', subset='val', result_dir='results', report_dir='reports')
    #experiments = ExperimentOTB('data/OTB', version=2015, result_dir='resultsOTB', report_dir='reportsOTB')
    #experiments = ExperimentVOT('../data/vot2018', version=2018, result_dir='../results_two', report_dir='../reports_two')

    '''run tracking experiments and report performance'''
    experiments.run(tracker, visualize=True)
    experiments.report([tracker.name])
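
ExperimentGOT10k.run drives any object that subclasses got10k.trackers.Tracker and implements init (first frame plus ground-truth box) and update (each later frame, returning the predicted box). A minimal identity tracker in the style of the toolkit's documentation shows the contract TrackerSiamRPN has to satisfy:

from got10k.trackers import Tracker

class IdentityTracker(Tracker):
    """Toy tracker illustrating the interface the experiment runner expects."""

    def __init__(self):
        super(IdentityTracker, self).__init__(name='IdentityTracker')

    def init(self, image, box):
        # first frame: remember the ground-truth box
        self.box = box

    def update(self, image):
        # later frames: return the current estimate (here: unchanged)
        return self.box
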
Example #6
def main():
    '''parameter initialization'''
    args = parser.parse_args()
    exp_name_dir = util.experiment_name_dir(args.experiment_name)
    '''model on gpu'''
    model = TrackerSiamRPN()
    '''setup train data loader'''
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset = GOT10k(root_dir, subset='train')
    elif name == 'VID':
        root_dir = '/store_ssd/ILSVRC'
        seq_dataset = ImageNetVID(root_dir, subset='train')
    elif name == 'All':
        root_dir_vid = '/store_ssd/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='train')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='train')
        seq_dataset = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset', len(seq_dataset))

    train_z_transforms = transforms.Compose([ToTensor()])
    train_x_transforms = transforms.Compose([
        RandomCrop([config.detection_img_size, config.detection_img_size],
                   config.max_translate),
        RandomScale(config.scale_resize),
        ToTensor()
    ])

    train_data = TrainDataLoader(seq_dataset, train_z_transforms,
                                 train_x_transforms, name)
    anchors = train_data.anchors
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              num_workers=config.train_num_workers,
                              pin_memory=True)
    '''setup val data loader'''
    name = 'GOT-10k'
    assert name in ['VID', 'GOT-10k', 'All']
    if name == 'GOT-10k':
        root_dir = args.train_path
        seq_dataset_val = GOT10k(root_dir, subset='val')
    elif name == 'VID':
        root_dir = '/store_ssd/ILSVRC'
        seq_dataset_val = ImageNetVID(root_dir, subset='val')
    elif name == 'All':
        root_dir_vid = '/store_ssd/ILSVRC'
        seq_datasetVID = ImageNetVID(root_dir_vid, subset='val')
        root_dir_got = args.train_path
        seq_datasetGOT = GOT10k(root_dir_got, subset='val')
        seq_dataset_val = util.data_split(seq_datasetVID, seq_datasetGOT)
    print('seq_dataset_val', len(seq_dataset_val))

    valid_z_transforms = transforms.Compose([ToTensor()])
    valid_x_transforms = transforms.Compose([ToTensor()])

    val_data = TrainDataLoader(seq_dataset_val, valid_z_transforms,
                               valid_x_transforms, name)
    val_loader = DataLoader(dataset=val_data,
                            batch_size=config.valid_batch_size,
                            shuffle=False,
                            num_workers=config.valid_num_workers,
                            pin_memory=True)
    '''load weights'''

    if args.checkpoint_path is not None and args.epoch_i > 0:
        checkpoint_path = os.path.join(args.checkpoint_path,
                                       'model_e{}.pth'.format(args.epoch_i))
        assert os.path.isfile(
            checkpoint_path), '{} is not valid checkpoint_path'.format(
                checkpoint_path)

        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        # a checkpoint is either a bare state_dict or wraps it under 'model'
        if 'model' in checkpoint.keys():
            model.net.load_state_dict(checkpoint['model'])
        else:
            model.net.load_state_dict(checkpoint)
        torch.cuda.empty_cache()
        print('Loaded weights from {}'.format(checkpoint_path))

    elif config.pretrained_model:
        checkpoint = torch.load(config.pretrained_model)
        # change name and load parameters
        checkpoint = {
            k.replace('features.features', 'featureExtract'): v
            for k, v in checkpoint.items()
        }
        model_dict = model.net.state_dict()
        model_dict.update(checkpoint)
        model.net.load_state_dict(model_dict)
        #torch.cuda.empty_cache()
        print('Loaded pretrained backbone from {}'.format(
            config.pretrained_model))
    '''train phase'''
    train_closses, train_rlosses, train_tlosses = (AverageMeter(),
                                                   AverageMeter(),
                                                   AverageMeter())
    val_closses, val_rlosses, val_tlosses = (AverageMeter(), AverageMeter(),
                                             AverageMeter())

    train_val_plot = SavePlot(exp_name_dir, 'train_val_plot')
    model.adjust_lr(args.epoch_i)

    for epoch in range(args.epoch_i, config.epoches):
        model.net.train()
        if config.fix_former_3_layers:
            util.freeze_layers(model.net)
        print('Train epoch {}/{}'.format(epoch + 1, config.epoches))
        train_loss = []
        with tqdm(total=config.train_epoch_size) as progbar:
            for i, dataset in enumerate(train_loader):

                closs, rloss, loss = model.step(epoch,
                                                dataset,
                                                anchors,
                                                i,
                                                train=True)

                closs_ = closs.cpu().item()

                if np.isnan(closs_):
                    sys.exit('classification loss is NaN, aborting')

                train_closses.update(closs.cpu().item())
                train_rlosses.update(rloss.cpu().item())
                train_tlosses.update(loss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(train_closses.avg),
                                    rloss='{:05.5f}'.format(train_rlosses.avg),
                                    tloss='{:05.3f}'.format(train_tlosses.avg))

                progbar.update()
                train_loss.append(train_tlosses.avg)

                if i >= config.train_epoch_size - 1:
                    '''save model'''
                    model.save(model, exp_name_dir, epoch)

                    break

        train_loss = np.mean(train_loss)
        '''val phase'''
        val_loss = []
        with tqdm(total=config.val_epoch_size) as progbar:
            print('Val epoch {}/{}'.format(epoch + 1, config.epoches))
            for i, dataset in enumerate(val_loader):

                val_closs, val_rloss, val_tloss = model.step(epoch,
                                                             dataset,
                                                             anchors,
                                                             train=False)

                closs_ = val_closs.cpu().item()

                if np.isnan(closs_):
                    sys.exit('validation classification loss is NaN, aborting')

                val_closses.update(val_closs.cpu().item())
                val_rlosses.update(val_rloss.cpu().item())
                val_tlosses.update(val_tloss.cpu().item())

                progbar.set_postfix(closs='{:05.3f}'.format(val_closses.avg),
                                    rloss='{:05.5f}'.format(val_rlosses.avg),
                                    tloss='{:05.3f}'.format(val_tlosses.avg))

                progbar.update()

                val_loss.append(val_tlosses.avg)

                if i >= config.val_epoch_size - 1:
                    break

        val_loss = np.mean(val_loss)
        train_val_plot.update(train_loss, val_loss)
        print('Train loss: {}, val loss: {}'.format(train_loss, val_loss))
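
The checkpoint-loading pattern above (accept either a bare state_dict or one wrapped under a 'model' key) recurs verbatim in Examples #2, #4, and #6 and could be factored out. A sketch with a hypothetical helper name load_weights:

import torch

def load_weights(net, checkpoint_path):
    # checkpoints saved by these scripts are either a bare state_dict or a
    # dict wrapping the state_dict under the key 'model'
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
    net.load_state_dict(state_dict)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()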