def classifier(args, clf, train, pretrained_dir="/"):
    """Build a ResNet classifier, optionally restoring trained weights.

    Arguments:
        args -- namespace providing ``num_classes``, ``image_channels`` and
            ``dataset`` (checkpoint sub-directory, e.g. "Cifar10" or "Cifar100")
        clf -- "resnet18", "resnet50", or "resnet101" (matched case-insensitively)
        train {bool} -- if True, return a freshly initialised network; otherwise
            load ``checkpoint['model_state_dict']`` from
            ``<pretrained_dir>/<args.dataset>/<clf>.pth``

    Keyword Arguments:
        pretrained_dir {str} -- pretrained weights root (default: {"/"})

    Returns:
        Model

    Raises:
        ValueError -- if ``clf`` is not one of the supported architectures
    """
    num_classes = args.num_classes
    input_channels = args.image_channels
    # The file name keeps the caller's spelling of ``clf`` (it is not lower-cased).
    checkpoint_dir = os.path.join(pretrained_dir, args.dataset, clf + '.pth')

    arch = clf.lower()
    if arch == 'resnet18':
        build = resnet18
    elif arch == 'resnet50':
        build = resnet50
    elif arch == 'resnet101':
        build = resnet101
    else:
        # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
        raise ValueError(
            "You can choose the model among [resnet18, resnet50, resnet101]")

    if train:
        return build(num_classes, input_channels)

    # Keep every storage on CPU so GPU-saved checkpoints load on any machine.
    checkpoint = torch.load(checkpoint_dir,
                            map_location=lambda storage, _loc: storage)
    net = build(num_classes, input_channels)
    net.load_state_dict(checkpoint['model_state_dict'])
    return net
Пример #2
0
def model_init(model_name):
    """Build a RetinaNet and load its pretrained weights from disk.

    The ResNet backbone depth is inferred from the number of entries in the
    saved state dict. Uses the module-level ``num_classes`` and ``model``
    (defined elsewhere).

    Args:
        model_name: only ``'retinanet'`` is supported.

    Returns:
        ``(retinanet, device)`` where ``device`` is CUDA when available, else CPU.

    Raises:
        ValueError: if ``model_name`` is unsupported, or the checkpoint's key
            count does not match any known backbone.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_name != 'retinanet':
        # Previously this fell through to an UnboundLocalError on weight_file_path.
        raise ValueError(
            "Unsupported model_name {!r}; only 'retinanet' is supported".format(model_name))
    weight_file_path = '/content/retinanet/retinanet50_pretrained.pth'

    # Load once, straight onto the target device. Loading GPU-saved tensors
    # without map_location fails on a CPU-only machine.
    state_dict = torch.load(weight_file_path, map_location=device)
    total_keys = len(state_dict)

    # Key-count bands per backbone -- presumably measured from this project's
    # RetinaNet checkpoints; confirm before relying on them for new weights.
    if 102 <= total_keys < 182:
        build = model.resnet18
    elif 182 <= total_keys < 267:
        build = model.resnet34
    elif 267 <= total_keys < 522:
        build = model.resnet50
    elif 522 <= total_keys < 777:
        build = model.resnet101
    elif total_keys >= 777:
        build = model.resnet152
    else:
        raise ValueError('Unsupported model backbone, must be one of resnet18, resnet34, resnet50, resnet101, resnet152')
    retinanet = build(num_classes=num_classes, pretrained=False)

    # strict=False: missing/unexpected checkpoint keys are ignored silently.
    retinanet.load_state_dict(state_dict, strict=False)
    print('model initialized..')

    return retinanet, device
Пример #3
0
    def __init__(self, config):
        """Set up an experiment from a config dict or a JSON config file.

        Creates ``./results/<name>/``, builds the ImageNet-pretrained model named
        by ``config['model']['name']`` on the GPU, and prepares the loss
        functions, SGD optimiser and step LR scheduler.

        Args:
            config: a parsed config dict, or a path to a JSON file holding one.
                Required top-level keys: ``name``, ``data_path``, ``balanced``,
                ``num_samples``. ``gpu_ids`` (a string such as ``"0,1"``) and
                ``model`` (``name``, ``num_classes``, ``init_lr`` and, for
                ``ws_dan_resnet50``, ``num_attentions``) are also read.

        Raises:
            AssertionError: if the config file is missing or a required key is absent.
            ValueError: if ``config['model']['name']`` is not a supported model.
        """
        if not isinstance(config, str):
            self.config = config
        else:
            assert os.path.exists(config), config
            # Context manager so the file handle is closed promptly.
            with open(config) as config_file:
                self.config = json.load(config_file)
        for required_key in ('name', 'data_path', 'balanced', 'num_samples'):
            assert required_key in self.config, required_key
        # Restrict which CUDA devices this process can see.
        os.environ["CUDA_VISIBLE_DEVICES"] = self.config['gpu_ids']

        self.date = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())

        # create project folder
        os.makedirs(os.path.join('./results/', self.config['name']), exist_ok=True)

        model_cfg = self.config['model']
        model_name = model_cfg['name']
        if model_name == 'resnet50':
            self.model = resnet50(pretrained=True,
                                  num_classes=model_cfg['num_classes'])
        elif model_name == 'resnet101':
            self.model = resnet101(pretrained=True,
                                   num_classes=model_cfg['num_classes'])
        elif model_name == 'inception_v3':
            self.model = inception_v3(pretrained=True,
                                      num_classes=model_cfg['num_classes'])
        elif model_name == 'vgg16':
            self.model = vgg16(pretrained=True,
                               num_classes=model_cfg['num_classes'])
        elif model_name == 'vgg19_bn':
            self.model = vgg19_bn(pretrained=True,
                                  num_classes=model_cfg['num_classes'])
        elif model_name == 'ws_dan_resnet50':
            self.model = ws_dan_resnet50(
                pretrained=True,
                num_classes=model_cfg['num_classes'],
                num_attentions=model_cfg['num_attentions'])
        else:
            # Previously an unknown name surfaced later as an AttributeError on self.model.
            raise ValueError('Unsupported model name: {}'.format(model_name))
        self.model.cuda()

        gpu_ids = self.config['gpu_ids']
        if isinstance(gpu_ids, str):
            # Count comma-separated ids, not characters, so "10" is a single GPU.
            gpu_ids = [gpu for gpu in gpu_ids.split(',') if gpu.strip()]
        if len(gpu_ids) > 1:
            self.model = nn.DataParallel(self.model)

        # Per-class loss weights; adjust them when the data is imbalanced.
        # NOTE: two weights means this assumes a 2-class problem, because
        # CrossEntropyLoss needs exactly one weight per class.
        self.criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(
            [1.0, 3.0]).cuda())
        self.criterion_attention = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=self.config['model']['init_lr'],
                                   momentum=0.9,
                                   weight_decay=1e-4)
        # Multiply the learning rate by 0.6 every 10 scheduler steps.
        self.exp_lr_schedler = lr_scheduler.StepLR(self.optimizer,
                                                   step_size=10,
                                                   gamma=0.6)
Пример #4
0
def build_model(lr):
    """Create a 9-class ResNet-101 (no pretrained weights) and an SGD optimiser.

    Args:
        lr: learning rate for plain SGD.

    Returns:
        ``(net, opt)``. When CUDA is available, the network is moved to the
        GPU and cuDNN benchmark mode is switched on.
    """
    network = model.resnet101(pretrained=False, num_classes=9)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        network.cuda()
        # Let cuDNN benchmark and pick the fastest convolution algorithms.
        torch.backends.cudnn.benchmark = True

    # NOTE(review): ``params()`` is not a standard nn.Module method (that is
    # ``parameters()``); presumably the project's model class provides it.
    optimizer = torch.optim.SGD(network.params(), lr)
    return network, optimizer
Пример #5
0
    def load_model(self):
        """Rebuild the classifier stored at ``self.model_checkpoint_file_path``.

        Reads the checkpoint onto CPU storage, then builds the network named by
        ``checkpoint['args'].model_arc`` with the class count implied by
        ``checkpoint['args'].model_type``. The network is wrapped in
        ``nn.DataParallel`` before loading ``model_state_dict``; presumably the
        checkpoint was saved from a DataParallel model, so its keys carry the
        ``module.`` prefix. Training metadata is then copied onto ``self`` and
        the model is put in eval mode.

        Raises:
            NotImplementedError: if ``model_type`` or ``model_arc`` is not supported.
        """
        # NOTE(review): the checkpoint pickles an args object; on recent PyTorch
        # versions torch.load may need weights_only=False for that -- confirm.
        self.checkpoint = torch.load(self.model_checkpoint_file_path,
                                     map_location=lambda storage, loc: storage)
        self.model_args = self.checkpoint['args']

        self.num_classes = None
        if self.model_args.model_type == 'food179':
            self.num_classes = 179
        elif self.model_args.model_type == 'nsfw':
            self.num_classes = 5
        else:
            # Raising a plain str is itself a TypeError in Python 3.
            raise NotImplementedError('Not Implemented!')

        arc = self.model_args.model_arc
        if arc in ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'):
            self.model = getattr(model, arc)(num_classes=self.num_classes,
                                             zero_init_residual=True)
        elif arc == 'mobilenet':
            self.model = model.MobileNetV2(n_class=self.num_classes,
                                           input_size=256)
        else:
            raise NotImplementedError('Not Implemented!')

        self.model = nn.DataParallel(self.model)
        self.model.load_state_dict(self.checkpoint['model_state_dict'])
        self.model_epoch = self.checkpoint['epoch']
        self.model_test_acc = self.checkpoint['test_acc']
        self.model_best_acc = self.checkpoint['best_acc']
        self.model_test_acc_top5 = self.checkpoint['test_acc_top5']
        self.model_class_to_idx = self.checkpoint['class_to_idx']
        # Reverse mapping: class index -> class name.
        self.model_idx_to_class = {
            idx: name
            for name, idx in self.model_class_to_idx.items()
        }
        self.model_train_history_dict = self.checkpoint['train_history_dict']
        # Normalisation statistics saved alongside the weights.
        self.mean = self.checkpoint['NORM_MEAN']
        self.std = self.checkpoint['NORM_STD']
        self.model.eval()
Пример #6
0
def main(args=None):
    """Build the guipang data loaders, a RetinaNet model and its optimiser.

    Settings come from the module-level ``cfg`` dict; ``args`` is unused.
    NOTE(review): this snippet appears truncated -- ``data_loader`` and
    ``optimizer`` are built but not used here.

    Raises:
        ValueError: if ``cfg['depth']`` is not a supported ResNet depth.
    """
    splits = ['train', 'val']
    # To use the qiafan dataset instead: qiafan(cfg=cfg['dataset_qiafan'], part=x)
    data_set = {
        x: guipang(cfg=cfg['dataset_guipang'], part=x) for x in splits
    }
    data_loader = {
        x: data.DataLoader(data_set[x], batch_size=cfg['batch_size'],
                           num_workers=4, shuffle=True, pin_memory=False)
        for x in splits
    }

    # Create the model
    depth = cfg['depth']
    builders = {
        18: 'resnet18',
        34: 'resnet34',
        50: 'resnet50',
        101: 'resnet101',
        152: 'resnet152',
    }
    if depth not in builders:
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')
    # The original referenced an undefined ``dataset_train``; the training split
    # built above is presumably what was meant -- confirm guipang has num_classes().
    retinanet = getattr(model, builders[depth])(
        num_classes=data_set['train'].num_classes(), pretrained=True)

    use_gpu = True

    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet = torch.nn.DataParallel(retinanet).cuda()

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)
Пример #7
0
    def load_state(self, _type='test'):
        """Rebuild the configured model and load weights from a checkpoint file.

        Args:
            _type: ``'test'`` loads ``./<config['test']['checkpoint']>``;
                ``'test_batch'`` loads ``config['test']['checkpoint']`` as given;
                anything else loads
                ``./results/<name>/checkpoints/<config['inference']['checkpoint']>``.

        The checkpoint must contain a ``'state_dict'`` entry. The model is moved
        to the GPU and wrapped in ``nn.DataParallel`` when more than one GPU id
        is configured.

        Raises:
            ValueError: if ``config['model']['name']`` is not a supported model.
        """
        model_cfg = self.config['model']
        model_name = model_cfg['name']
        if model_name == 'resnet50':
            self.model = resnet50(pretrained=False,
                                  num_classes=model_cfg['num_classes'])
        elif model_name == 'resnet101':
            self.model = resnet101(pretrained=False,
                                   num_classes=model_cfg['num_classes'])
        elif model_name == 'inception_v3':
            self.model = inception_v3(pretrained=False,
                                      num_classes=model_cfg['num_classes'])
        elif model_name == 'vgg16':
            self.model = vgg16(pretrained=False,
                               num_classes=model_cfg['num_classes'])
        elif model_name == 'vgg19_bn':
            self.model = vgg19_bn(pretrained=False,
                                  num_classes=model_cfg['num_classes'])
        elif model_name == 'ws_dan_resnet50':
            # NOTE(review): unlike the other branches this requests pretrained
            # weights, which the checkpoint below overwrites anyway -- confirm intent.
            self.model = ws_dan_resnet50(
                pretrained=True,
                num_classes=model_cfg['num_classes'],
                num_attentions=model_cfg['num_attentions'])
        else:
            # Previously an unknown name failed later with an AttributeError
            # (or silently reused a previously built model).
            raise ValueError('Unsupported model name: {}'.format(model_name))

        if _type == 'test':
            checkpoints = os.path.join('.', self.config['test']['checkpoint'])
        elif _type == 'test_batch':
            checkpoints = self.config['test']['checkpoint']
        else:
            checkpoints = os.path.join('./results', self.config['name'],
                                       'checkpoints',
                                       self.config['inference']['checkpoint'])
        self.model.load_state_dict(torch.load(checkpoints)['state_dict'])
        self.model.cuda()

        gpu_ids = self.config['gpu_ids']
        if isinstance(gpu_ids, str):
            # Count comma-separated ids, not characters, so "10" is a single GPU.
            gpu_ids = [gpu for gpu in gpu_ids.split(',') if gpu.strip()]
        if len(gpu_ids) > 1:
            self.model = nn.DataParallel(self.model)
Пример #8
0
def main(args=None):
    """Evaluate a saved RetinaNet model on the JinNan validation split.

    The model is loaded as a whole pickled object from ``--model_path`` and
    evaluated with ``evaluate_jinnan`` on the GPU. ``--depth`` is still
    validated for compatibility, but no model is built from it: the original
    built (and downloaded weights for) a pretrained network only to discard it.
    ``--threshold`` is accepted but not used here.

    Raises:
        ValueError: if ``--depth`` is not one of 18, 34, 50, 101, 152.
    """
    from dataloader import JinNanDataset, Normalizer, Resizer
    from torchvision import transforms
    import model  # noqa: F401 -- presumably needed so torch.load can unpickle the model classes
    import torch
    import argparse

    parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')
    parser.add_argument('--dataset', default='jingnan', help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--threshold', help='treshold')
    parser.add_argument('--dataset_path', help='Path to file containing training and validation annotations (optional, see readme)')
    parser.add_argument('--model_path', help=('the model path'))
    parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=50)
    parser = parser.parse_args(args)

    # Validation data is not augmented: Augmenter is a training-time transform
    # (presumably random, which would make evaluation non-deterministic).
    dataset_val = JinNanDataset(parser.dataset_path, set_name='val',
                                transform=transforms.Compose([Normalizer(), Resizer()]))

    if parser.depth not in (18, 34, 50, 101, 152):
        raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    retinanet = torch.load(parser.model_path)
    use_gpu = True

    if use_gpu:
        retinanet = retinanet.cuda()
    retinanet.eval()
    print('Evaluating dataset')
    evaluate_jinnan(dataset_val, retinanet)
def run_evaluate():
    """Evaluate the pretrained relation ResNet-101 on the test set and print the result.

    Depends on names defined elsewhere: ``opt`` (``batch_size``,
    ``num_workers``), ``TestDataset``, ``data_``, ``model`` and an ``eval``
    callable. NOTE(review): the builtin ``eval`` cannot accept these arguments,
    so a project-level ``eval(dataloader, model, n)`` is presumably in scope.
    """
    loader = data_.DataLoader(
        TestDataset(opt),
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        shuffle=False,
    )

    net = torch.nn.DataParallel(model.resnet101(20, True)).cuda()
    net.load_state_dict(torch.load('Weights/resnet101_relation_e2e_20.pt'))
    net.module.use_preset(isTraining=False, preset='evaluate')
    net.eval()

    # Freeze every parameter owned by the wrapped model's direct children.
    frozen = (p for submodule in net.module.children() for p in submodule.parameters())
    for weight in frozen:
        weight.requires_grad = False

    result = eval(loader, net, 100)
    print(result)
Пример #10
0
    def build(self, depth=50, learning_rate=1e-5, ratios=None, scales=None):
        """Create the RetinaNet model, its Adam optimiser and LR scheduler.

        Args:
            depth: ResNet backbone depth; one of 18, 34, 50, 101, 152.
            learning_rate: initial Adam learning rate.
            ratios: anchor aspect ratios (default ``[0.5, 1, 2]``).
            scales: anchor scales (default ``[2**0, 2**(1/3), 2**(2/3)]``).

        When ``self.checkpoint`` is set, the model, optimiser and scheduler
        states are restored from its ``'model'``, ``'optimizer'`` and
        ``'scheduler'`` entries.

        Raises:
            ValueError: if ``depth`` is not supported.
        """
        # Fresh lists per call: list defaults in the signature would be one
        # object shared by every call (and it is stored on self below).
        if ratios is None:
            ratios = [0.5, 1, 2]
        if scales is None:
            scales = [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]

        builder_name = None
        for supported_depth, name in ((18, 'resnet18'), (34, 'resnet34'),
                                      (50, 'resnet50'), (101, 'resnet101'),
                                      (152, 'resnet152')):
            if depth == supported_depth:
                builder_name = name
                break
        if builder_name is None:
            raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')

        # Create the model
        retinanet = getattr(model, builder_name)(
            num_classes=self.dataset_train.num_classes(), ratios=ratios, scales=scales,
            weights_dir=self.weights_dir_path,
            pretrained=True)
        self.retinanet = retinanet.to(device=self.device)
        self.retinanet.training = True
        self.optimizer = optim.Adam(self.retinanet.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, patience=3, verbose=True)

        if self.checkpoint is not None:
            self.retinanet.load_state_dict(self.checkpoint['model'])
            self.optimizer.load_state_dict(self.checkpoint['optimizer'])
            # TODO: confirm restoring the optimizer and scheduler this way is correct when resuming.
            self.scheduler.load_state_dict(self.checkpoint['scheduler'])
        self.ratios = ratios
        self.scales = scales
        self.depth = depth
Пример #11
0
    def set_models(self, dataset_train):
        """Create the RetinaNet for ``self.depth`` with its optimiser and scheduler.

        The model is pretrained, has ``dataset_train.num_classes()`` outputs and
        is wrapped in ``nn.DataParallel`` when several GPUs are visible. It is
        then moved to ``self.device`` and paired with an Adam optimiser
        (``self.lr``) and a ReduceLROnPlateau scheduler. Also resets
        ``self.loss_hist`` to a 500-entry deque.

        Raises:
            ValueError: if ``self.depth`` is not 18, 34, 50, 101 or 152.
        """
        depth_table = ((18, 'resnet18'), (34, 'resnet34'), (50, 'resnet50'),
                       (101, 'resnet101'), (152, 'resnet152'))
        factory_name = next(
            (name for depth, name in depth_table if self.depth == depth), None)
        if factory_name is None:
            raise ValueError(
                'Unsupported model depth, must be one of 18, 34, 50, 101, 152')
        network = getattr(model, factory_name)(
            num_classes=dataset_train.num_classes(), pretrained=True)

        gpu_total = torch.cuda.device_count()
        if gpu_total > 1:
            print("Let's use", gpu_total, "GPUs!")
            # DataParallel splits each batch along dim 0 across the visible GPUs.
            network = nn.DataParallel(network)

        self.retinanet = network.to(self.device)
        self.retinanet.training = True
        self.optimizer = optim.Adam(self.retinanet.parameters(), lr=self.lr)

        # Lower the learning rate when the monitored loss stops improving.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=3, verbose=True)

        self.loss_hist = collections.deque(maxlen=500)
Пример #12
0
def main(args=None):
    """Evaluate a RetinaNet+HTR model: detection mAP and transcription CER.

    Loads the model from ``--model`` if that file exists; otherwise builds a
    pretrained ResNet-backed model of ``--depth``. It then:

    * reports mAP on the ``--csv_val`` set;
    * for every image of ``--csv_box_annot`` (or of ``--csv_val`` when no box
      annotations are given), compares the predicted transcript with the first
      line of ``<image_name>.txt`` in the ``--csv_val`` image's directory, and
      prints the running character error rate (CER).

    Raises:
        ValueError: for an unsupported dataset type or depth, or when
            ``--csv_classes`` / ``--csv_val`` is missing.
    """
    parser = argparse.ArgumentParser(description='Simple testing script for RetinaNet network.')

    parser.add_argument('--dataset', help='Dataset type, must be one of csv or coco.', default="csv")
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument('--csv_classes', help='Path to file containing class list (see readme)', default="binary_class.csv")
    parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')
    parser.add_argument('--csv_box_annot', help='Path to file containing predicted box annotations ')

    parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=18)
    parser.add_argument('--epochs', help='Number of epochs', type=int, default=500)
    parser.add_argument('--model', help='Path of .pt file with trained model', default='esposallescsv_retinanet_0.pt')
    parser.add_argument('--model_out', help='Path of .pt file with trained model to save', default='trained')

    parser.add_argument('--score_threshold', help='Score above which boxes are kept', type=float, default=0.15)
    parser.add_argument('--nms_threshold', help='Overlap threshold for non-maximum suppression', type=float, default=0.2)
    parser.add_argument('--max_epochs_no_improvement', help='Max epochs without improvement', default=100)
    parser.add_argument('--max_boxes', help='Max boxes to be fed to recognition', type=int, default=50)
    parser.add_argument('--seg_level', help='Line or word, to choose anchor aspect ratio', default='line')
    parser.add_argument('--htr_gt_box', help='Train recognition branch with box gt (for debugging)', default=False)
    parser = parser.parse_args(args)

    # Create the datasets
    if parser.dataset != 'csv':
        raise ValueError('Dataset type not understood (must be csv or coco), exiting.')
    if parser.csv_classes is None:
        raise ValueError('Must provide --csv_classes for a CSV dataset.')
    if parser.csv_val is None:
        # Everything below (alphabet, mAP, ground-truth paths) needs the validation set;
        # previously this only printed a message and then crashed later.
        raise ValueError('Must provide --csv_val: evaluation needs validation annotations.')

    dataset_val = CSVDataset(train_file=parser.csv_val, class_list=parser.csv_classes,
                             transform=transforms.Compose([Normalizer(), Resizer()]))
    box_annot_data = None
    if parser.csv_box_annot is not None:
        box_annot_data = CSVDataset(train_file=parser.csv_box_annot, class_list=parser.csv_classes,
                                    transform=transforms.Compose([Normalizer(), Resizer()]))

    # In-order, one image per batch, so the loop index below matches image_names[idx].
    # NOTE(review): the previous AspectRatioBasedSampler presumably reorders images by
    # aspect ratio, which would misalign idx with image_names -- confirm.
    eval_source = box_annot_data if box_annot_data is not None else dataset_val
    dataloader_box_annot = DataLoader(eval_source, batch_size=1, shuffle=False,
                                      num_workers=0, collate_fn=collater)

    os.makedirs('trained_models', exist_ok=True)

    # Create the model
    alphabet = dataset_val.alphabet
    if os.path.exists(parser.model):
        retinanet = torch.load(parser.model)
    elif parser.depth == 18:
        retinanet = model.resnet18(num_classes=dataset_val.num_classes(), pretrained=True,
                                   max_boxes=int(parser.max_boxes),
                                   score_threshold=float(parser.score_threshold),
                                   seg_level=parser.seg_level, alphabet=alphabet)
    elif parser.depth in (34, 50, 101, 152):
        # These depths previously referenced an undefined ``dataset_train``.
        # NOTE(review): unlike resnet18 they get no HTR kwargs (max_boxes, alphabet, ...);
        # confirm whether the deeper builders need them too.
        build = getattr(model, 'resnet{}'.format(parser.depth))
        retinanet = build(num_classes=dataset_val.num_classes(), pretrained=True)
    else:
        raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    use_gpu = True
    if use_gpu:
        retinanet = retinanet.cuda()
    retinanet = torch.nn.DataParallel(retinanet).cuda()

    retinanet.module.freeze_bn()
    best_cer = 1000  # not updated in this script; only printed
    cers = []
    retinanet.eval()
    retinanet.module.epochs_only_det = 0
    if parser.score_threshold is not None:
        retinanet.module.score_threshold = float(parser.score_threshold)

    mAP = csv_eval.evaluate(dataset_val, retinanet, score_threshold=retinanet.module.score_threshold)
    aps = [v[0] for v in mAP.values()]
    print("VALID mAP:", np.mean(aps))

    print("score th", retinanet.module.score_threshold)
    text_gt_lines = None
    transcript_pred = None
    for idx, data in enumerate(dataloader_box_annot):
        print("Eval CER on validation set:", idx, "/", len(dataloader_box_annot), "\r")
        image_name = eval_source.image_names[idx].split('/')[-1].split('.')[-2]
        # Ground-truth text lives next to the validation image.
        text_gt_path = "/".join(dataset_val.image_names[idx].split('/')[:-1])
        text_gt = os.path.join(text_gt_path, image_name + '.txt')
        with open(text_gt, 'r') as f:
            # First line only, including its newline (same as readlines()[0]).
            text_gt_lines = f.readline()
        transcript_pred = get_transcript(image_name, data, retinanet, retinanet.module.score_threshold,
                                         float(parser.nms_threshold), dataset_val, alphabet)
        # Guard against an empty ground-truth line.
        cers.append(float(editdistance.eval(transcript_pred, text_gt_lines)) / max(len(text_gt_lines), 1))
        print("GT", text_gt_lines)
        print("PREDS SAMPLE:", transcript_pred)
        print("VALID CER:", np.mean(cers), "best CER", best_cer)
    if cers:
        # Final summary (previously crashed with NameError when there were no images).
        print("GT", text_gt_lines)
        print("PREDS SAMPLE:", transcript_pred)
        print("VALID CER:", np.mean(cers), "best CER", best_cer)
Пример #13
0
    shuffle=False,
    target_size=(im_height, im_width),
    class_mode='categorical')
total_val = val_data_gen.n  # total number of validation samples

# Class-name -> index mapping produced by the training generator.
class_indices = train_data_gen.class_indices
# Invert it to index -> class name, for decoding predictions.
inverse_dict = {val: key for key, val in class_indices.items()}
# Write the index -> class-name mapping to class_indices.json.
with open('class_indices.json', 'w') as json_file:
    json.dump(inverse_dict, json_file, indent=4)

feature = resnet101(num_classes=5, include_top=False)
# feature.build((None, 224, 224, 3))  # when using subclass model
feature.load_weights('pretrain_weights.ckpt')  # load the pretrained backbone weights
feature.trainable = False  # freeze the pretrained backbone during training
feature.summary()  # print the backbone's layer/parameter summary

# Append a two-layer fully-connected head for custom 5-class classification.
# The hidden Dense layer needs a non-linearity: two stacked Dense layers
# without one collapse into a single linear map.
model = tf.keras.Sequential([
    feature,
    tf.keras.layers.GlobalAvgPool2D(),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(1024, activation='relu'),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(5),
    tf.keras.layers.Softmax()
])
Пример #14
0
def main(args=None):
    """Train a RetinaNet detector on a COCO or CSV dataset.

    After every epoch the model is evaluated (when a validation set exists) and
    saved whole (DataParallel-wrapped) as
    ``<dataset>_retinanet_dilation_experiment1_<epoch>.pt``. The final model is
    saved as ``model_final_dilation_experiment1.pt``.

    ``--resume PATH`` loads such a whole-model checkpoint once, before the
    optimiser is created, and training runs epochs ``--start-epoch`` ..
    ``--epochs - 1``.

    Raises:
        ValueError: for an unsupported dataset type or depth, or missing paths.
    """
    parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset', help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument('--csv_train', help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes', help='Path to file containing class list (see readme)')
    parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')

    parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=50)
    parser.add_argument('--epochs', help='Number of epochs', type=int, default=100)
    parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number (useful on restarts)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)')

    parser = parser.parse_args(args)

    dataset_train, dataset_val = _retinanet_datasets(parser)

    sampler = AspectRatioBasedSampler(dataset_train, batch_size=4, drop_last=False)
    dataloader_train = DataLoader(dataset_train, num_workers=3, collate_fn=collater, batch_sampler=sampler)

    # Resume once, BEFORE the optimiser is built: replacing the model later would
    # leave the optimiser updating the old model's parameters.
    retinanet = None
    if parser.resume:
        if os.path.isfile(parser.resume):
            print("=>loading checkpoint '{}'".format(parser.resume))
            retinanet = torch.load(parser.resume)
            print(retinanet)
            print("=> loaded checkpoint '{}' (resuming at epoch {})".format(parser.resume, parser.start_epoch))
        else:
            print("=> no checkpoint found at '{}'".format(parser.resume))

    if retinanet is None:
        retinanet = _retinanet_model(parser.depth, dataset_train.num_classes())

    use_gpu = True
    if use_gpu:
        retinanet = retinanet.cuda()
    # Checkpoints saved below are already DataParallel-wrapped; don't wrap twice.
    if not isinstance(retinanet, torch.nn.DataParallel):
        retinanet = torch.nn.DataParallel(retinanet).cuda()

    optimizer = optim.SGD(retinanet.parameters(), lr=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
    loss_hist = collections.deque(maxlen=500)

    print('Num training images: {}'.format(len(dataset_train)))

    for epoch_num in range(parser.start_epoch, parser.epochs):
        retinanet.train()
        # DataParallel does not forward custom methods; call freeze_bn on the wrapped model.
        retinanet.module.freeze_bn()

        epoch_loss = []
        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()

                classification_loss, regression_loss = retinanet([data['img'].cuda().float(), data['annot'].cuda()])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()

                loss = classification_loss + regression_loss

                if bool(loss == 0):
                    continue

                loss.backward()

                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

                optimizer.step()

                loss_hist.append(float(loss))
                epoch_loss.append(float(loss))

                print('Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'.format(epoch_num, iter_num, float(classification_loss), float(regression_loss), np.mean(loss_hist)))

                del classification_loss
                del regression_loss
            except Exception as e:
                # Best-effort: report and skip a failing batch rather than abort the run.
                print(e)
                continue

        _evaluate_retinanet(parser, dataset_val, retinanet)

        # np.mean([]) is nan; don't feed that to the plateau scheduler.
        if epoch_loss:
            scheduler.step(np.mean(epoch_loss))

        torch.save(retinanet, '{}_retinanet_dilation_experiment1_{}.pt'.format(parser.dataset, epoch_num))

    retinanet.eval()

    torch.save(retinanet, 'model_final_dilation_experiment1.pt')


def _retinanet_datasets(parser):
    """Build ``(dataset_train, dataset_val)`` from the parsed options; val may be None for CSV."""
    if parser.dataset == 'coco':
        if parser.coco_path is None:
            raise ValueError('Must provide --coco_path when training on COCO,')

        dataset_train = CocoDataset(parser.coco_path, set_name='train2017', transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
        dataset_val = CocoDataset(parser.coco_path, set_name='val2017', transform=transforms.Compose([Normalizer(), Resizer()]))

    elif parser.dataset == 'csv':
        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train when training on CSV.')

        if parser.csv_classes is None:
            raise ValueError('Must provide --csv_classes when training on CSV.')

        dataset_train = CSVDataset(train_file=parser.csv_train, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Resizer()]))

    else:
        raise ValueError('Dataset type not understood (must be csv or coco), exiting.')

    return dataset_train, dataset_val


def _retinanet_model(depth, num_classes):
    """Build a pretrained RetinaNet with a ResNet backbone of the given depth."""
    names = {18: 'resnet18', 34: 'resnet34', 50: 'resnet50', 101: 'resnet101', 152: 'resnet152'}
    if depth not in names:
        raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')
    return getattr(model, names[depth])(num_classes=num_classes, pretrained=True)


def _evaluate_retinanet(parser, dataset_val, retinanet):
    """Run COCO or CSV evaluation on the validation set when one is available."""
    if parser.dataset == 'coco':
        print('Evaluating dataset')
        coco_eval.evaluate_coco(dataset_val, retinanet)
    elif parser.dataset == 'csv' and parser.csv_val is not None:
        print('Evaluating dataset')
        csv_eval.evaluate(dataset_val, retinanet)
Пример #15
0
def main(args=None):
    """Train a RetinaNet detector on COCO 2017 and log losses to tensorboard.

    Builds the train2017/val2017 ``CocoDataset``s under ``--coco_path``,
    creates a ResNet-``--depth`` RetinaNet, wraps it in ``DataParallel`` over
    the comma-separated ``--gpu_ids`` (the first id is used as the primary
    device), optionally restores weights from ``--checkpoint`` and trains
    with Adam + ``ReduceLROnPlateau`` for ``--epochs`` epochs, running
    ``coco_eval.evaluate_coco`` on val2017 after every epoch.

    Args:
        args: Optional argument list for ``argparse``; ``None`` parses the
            process command line.

    Side effects:
        Writes tensorboard event files to ``./log`` and saves model snapshots
        in the current directory: ``COCO_retinanet_epoch{E}_iter{I}.pt``
        every 1000 iterations, ``COCO_retinanet_{E}.pt`` after every epoch
        and ``model_final.pt`` at the end.
    """
    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--coco_path',
                        help='Path to COCO directory',
                        type=str,
                        default='./data/coco')
    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=50)
    parser.add_argument('--checkpoint',
                        help='The path to the checkpoint.',
                        type=str,
                        default=None)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=100)
    # The sampler used to hard-code a batch size of 4 and ignore this flag.
    # The default matches that effective value so runs without the flag are
    # unchanged, while an explicit --batch_size is now honoured.
    parser.add_argument('--batch_size',
                        help='Number of batch',
                        type=int,
                        default=4)
    parser.add_argument('--gpu_ids',
                        help='Gpu parallel',
                        type=str,
                        default='1, 2')

    parser = parser.parse_args(args)

    # Create the data loaders
    dataset_train = CocoDataset(parser.coco_path,
                                set_name='train2017',
                                transform=transforms.Compose(
                                    [Normalizer(),
                                     Augmenter(),
                                     Resizer()]))
    # Evaluation below is handed dataset_val directly, so no validation
    # DataLoader is built.
    dataset_val = CocoDataset(parser.coco_path,
                              set_name='val2017',
                              transform=transforms.Compose(
                                  [Normalizer(), Resizer()]))

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=parser.batch_size,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=16,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    # Create the model
    if parser.depth == 18:
        retinanet = model.resnet18(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 34:
        retinanet = model.resnet34(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 50:
        retinanet = model.resnet50(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 101:
        retinanet = model.resnet101(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    elif parser.depth == 152:
        retinanet = model.resnet152(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    else:
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    # --gpu_ids is a comma-separated list such as '1, 2'. The first id becomes
    # the current CUDA device and holds the DataParallel master weights.
    gpu_ids = [int(gpu_id) for gpu_id in parser.gpu_ids.split(',')]
    device = torch.device('cuda', gpu_ids[0])
    torch.cuda.set_device(device)
    retinanet = torch.nn.DataParallel(retinanet.to(device),
                                      device_ids=gpu_ids).to(device)

    if parser.checkpoint:
        # The checkpoint is loaded as a whole pickled model (presumably one of
        # the snapshots saved below); only its weights are copied over.
        pretrained = torch.load(parser.checkpoint).state_dict()
        retinanet.module.load_state_dict(pretrained)

    # add tensorboard to record train log
    writer = SummaryWriter('./log')
    # writer.add_graph(retinanet, input_to_model=[images, labels])

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     verbose=True)

    loss_hist = collections.deque(maxlen=500)

    retinanet.train()
    retinanet.module.freeze_bn()

    print('Num training images: {}'.format(len(dataset_train)))

    # Monotonic tensorboard step. Logging against the per-epoch iter_num made
    # every epoch overwrite the previous epoch's curve.
    global_step = 0

    for epoch_num in range(parser.epochs):

        retinanet.train()
        retinanet.module.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()

                # Annotations live under 'annot', the key the other training
                # scripts read from the same collater. The previous 'ann'
                # raised KeyError on every batch, which the except below
                # silently swallowed.
                classification_loss, regression_loss = retinanet(
                    [data['img'].to(device), data['annot'].to(device)])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()

                loss = classification_loss + regression_loss

                if bool(loss == 0):
                    continue

                loss.backward()

                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

                optimizer.step()

                loss_value = float(loss)
                loss_hist.append(loss_value)
                epoch_loss.append(loss_value)

                writer.add_scalar('Loss/train', loss_value, global_step)
                writer.add_scalar('Loss/reg_loss', float(regression_loss),
                                  global_step)
                writer.add_scalar('Loss/cls_loss', float(classification_loss),
                                  global_step)
                global_step += 1

                if (iter_num + 1) % 1000 == 0:
                    print('Save model')
                    torch.save(
                        retinanet.module,
                        'COCO_retinanet_epoch{}_iter{}.pt'.format(
                            epoch_num, iter_num))

                print(
                    'Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'
                    .format(epoch_num, iter_num, float(classification_loss),
                            float(regression_loss), np.mean(loss_hist)))

                del classification_loss
                del regression_loss
            except Exception as e:
                # Best effort: a failing batch is reported and skipped rather
                # than aborting the whole run.
                print(e)
                continue

        print('Evaluating dataset')

        coco_eval.evaluate_coco(dataset_val, retinanet, writer)

        # np.mean([]) is NaN. Don't feed that to the plateau scheduler when
        # every batch of the epoch was skipped or failed.
        if epoch_loss:
            scheduler.step(np.mean(epoch_loss))

        torch.save(retinanet.module, 'COCO_retinanet_{}.pt'.format(epoch_num))

    retinanet.eval()

    # Plain filename: the old '.format(epoch_num)' had no placeholder and
    # raised NameError when --epochs was 0.
    torch.save(retinanet, 'model_final.pt')
Пример #16
0
def main(args=None):
    """Train a RetinaNet detector on a COCO-2014 or CSV dataset.

    Builds the train (and, when available, validation) datasets selected by
    ``--dataset``, creates a ResNet-``--depth`` RetinaNet wrapped in
    ``DataParallel`` over all visible GPUs, and trains it with SGD or Adam
    plus ``ReduceLROnPlateau`` for ``--epochs`` epochs. The state dict is
    saved to ``./checkpoints/`` after every epoch, followed by an evaluation
    pass (COCO, or CSV when ``--csv_val`` was given).

    Args:
        args: Optional argument list for ``argparse``; ``None`` parses the
            process command line.

    Raises:
        ValueError: on a missing required path, an unknown dataset type,
            an unsupported depth or an unsupported optimizer.
    """
    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset',
                        help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument(
        '--csv_train',
        help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes',
                        help='Path to file containing class list (see readme)')
    parser.add_argument(
        '--csv_val',
        help=
        'Path to file containing validation annotations (optional, see readme)'
    )
    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=50)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=100)
    parser.add_argument('--optimizer',
                        help='[SGD | Adam]',
                        type=str,
                        default='SGD')
    parser.add_argument('--model', help='Path to model (.pt) file.')
    parser = parser.parse_args(args)

    # Create the data loaders
    print("\n[Phase 1]: Creating DataLoader for {} dataset".format(
        parser.dataset))
    if parser.dataset == 'coco':
        if parser.coco_path is None:
            raise ValueError('Must provide --coco_path when training on COCO,')

        dataset_train = CocoDataset(parser.coco_path,
                                    set_name='train2014',
                                    transform=transforms.Compose(
                                        [Normalizer(),
                                         Augmenter(),
                                         Resizer()]))
        dataset_val = CocoDataset(parser.coco_path,
                                  set_name='val2014',
                                  transform=transforms.Compose(
                                      [Normalizer(), Resizer()]))

    elif parser.dataset == 'csv':
        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train when training on CSV.')

        if parser.csv_classes is None:
            raise ValueError(
                'Must provide --csv_classes when training on CSV.')

        dataset_train = CSVDataset(train_file=parser.csv_train,
                                   class_list=parser.csv_classes,
                                   transform=transforms.Compose(
                                       [Normalizer(),
                                        Augmenter(),
                                        Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val,
                                     class_list=parser.csv_classes,
                                     transform=transforms.Compose(
                                         [Normalizer(),
                                          Resizer()]))

    else:
        raise ValueError(
            'Dataset type not understood (must be csv or coco), exiting.')

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=8,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=8,
                                  collate_fn=collater,
                                  batch_sampler=sampler)
    # Evaluation below is handed dataset_val directly, so no validation
    # DataLoader is built.

    # Create the model
    if parser.depth == 18:
        retinanet = model.resnet18(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 34:
        retinanet = model.resnet34(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 50:
        retinanet = model.resnet50(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 101:
        retinanet = model.resnet101(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    elif parser.depth == 152:
        retinanet = model.resnet152(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    else:
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    print('| Num training images: {}'.format(len(dataset_train)))
    # dataset_val is None for CSV training without --csv_val.
    num_val = len(dataset_val) if dataset_val is not None else 0
    print('| Num test images : {}'.format(num_val))

    print("\n[Phase 2]: Preparing RetinaNet Detection Model...")
    use_gpu = torch.cuda.is_available()
    # `device` is used by the training loop and evaluation, so it must exist
    # on CPU-only machines too (it used to be defined only under CUDA).
    device = torch.device('cuda' if use_gpu else 'cpu')
    retinanet = retinanet.to(device)

    retinanet = torch.nn.DataParallel(retinanet,
                                      device_ids=range(
                                          torch.cuda.device_count()))
    print("| Using %d GPUs for Train/Validation!" % torch.cuda.device_count())
    retinanet.training = True

    if parser.optimizer == 'Adam':
        optimizer = optim.Adam(retinanet.parameters(),
                               lr=1e-5)  # not mentioned
        print("| Adam Optimizer with Learning Rate = {}".format(1e-5))
    elif parser.optimizer == 'SGD':
        optimizer = optim.SGD(retinanet.parameters(),
                              lr=1e-2,
                              momentum=0.9,
                              weight_decay=1e-4)
        print("| SGD Optimizer with Learning Rate = {}".format(1e-2))
    else:
        raise ValueError('Unsupported Optimizer, must be one of [SGD | Adam]')

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     verbose=True)
    loss_hist = collections.deque(maxlen=500)

    retinanet.train()
    # Freeze the BN parameters to ImageNet configuration
    retinanet.module.freeze_bn()

    # Make sure the 'checkpoints' directory exists
    os.makedirs('./checkpoints/', exist_ok=True)

    print("\n[Phase 3]: Training Model on {} dataset...".format(
        parser.dataset))
    for epoch_num in range(parser.epochs):
        epoch_loss = []
        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()
                classification_loss, regression_loss = retinanet(
                    [data['img'].to(device), data['annot']])
                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()
                loss = classification_loss + regression_loss
                if bool(loss == 0):
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.001)
                optimizer.step()
                loss_hist.append(float(loss))
                epoch_loss.append(float(loss))

                sys.stdout.write('\r')
                sys.stdout.write(
                    '| Epoch: {} | Iteration: {}/{} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'
                    .format(epoch_num + 1, iter_num + 1, len(dataloader_train),
                            float(classification_loss), float(regression_loss),
                            np.mean(loss_hist)))
                sys.stdout.flush()

                del classification_loss
                del regression_loss

            except Exception as e:
                # Best effort: a failing batch is reported and skipped rather
                # than aborting the whole run.
                print(e)
                continue

        print("\n| Saving current best model at epoch {}...".format(epoch_num +
                                                                    1))
        torch.save(
            retinanet.state_dict(),
            './checkpoints/{}_retinanet_{}.pt'.format(parser.dataset,
                                                      epoch_num + 1))

        if parser.dataset == 'coco':
            coco_eval.evaluate_coco(dataset_val, retinanet, device)

        elif parser.dataset == 'csv' and parser.csv_val is not None:
            mAP = csv_eval.evaluate(dataset_val, retinanet, device)

        # np.mean([]) is NaN. Skip the scheduler step when no batch succeeded.
        if epoch_loss:
            scheduler.step(np.mean(epoch_loss))

    retinanet.eval()
    torch.save(retinanet.state_dict(), './checkpoints/model_final.pt')
def main():
    """Build, optionally resume, and launch training of the classifier.

    Reads its configuration from the module-level ``args`` namespace. It
    picks the class count and normalisation statistics from
    ``args.model_type`` and the architecture from ``args.model_arc``. The
    model is wrapped in ``nn.DataParallel`` (on CUDA when ``args.cuda``), the
    optimizer/scheduler pair is chosen from ``args.train_optimizer``, and the
    state is optionally restored from ``args.model_checkpoint_path``. Finally
    it hands everything to ``train_ops``.

    Side effects:
        Assigns the module globals ``NORM_MEAN``, ``NORM_STD`` and
        ``coconut_model``; assigns ``train_history_dict`` when resuming.
        When CUDA is enabled, scales ``args.train_batch_size`` and
        ``args.test_batch_size`` by the number of visible GPUs.

    Raises:
        NotImplementedError: for an unknown model type, architecture or
            optimizer.
    """
    global NORM_MEAN, NORM_STD, coconut_model, train_history_dict

    for arg in vars(args):
        print(str(arg) + ': ' + str(getattr(args, arg)))
    print('=' * 100)

    # Build model based on dataset and architecture.
    num_classes = None
    if args.model_type == 'food179':
        num_classes = 179
        NORM_MEAN = FOOD179_MEAN
        NORM_STD = FOOD179_STD
    elif args.model_type == 'nsfw':
        num_classes = 5
        NORM_MEAN = NSFW_MEAN
        NORM_STD = NSFW_STD
    else:
        # Raising a plain string is itself a TypeError in Python 3.
        raise NotImplementedError(
            'Not Implemented! Unknown model_type: {}'.format(args.model_type))

    if args.model_arc == 'resnet18':
        coconut_model = model.resnet18(num_classes=num_classes,
                                       zero_init_residual=True)
    elif args.model_arc == 'resnet34':
        coconut_model = model.resnet34(num_classes=num_classes,
                                       zero_init_residual=True)
    elif args.model_arc == 'resnet50':
        coconut_model = model.resnet50(num_classes=num_classes,
                                       zero_init_residual=True)
    elif args.model_arc == 'resnet101':
        coconut_model = model.resnet101(num_classes=num_classes,
                                        zero_init_residual=True)
    elif args.model_arc == 'resnet152':
        coconut_model = model.resnet152(num_classes=num_classes,
                                        zero_init_residual=True)
    elif args.model_arc == 'mobilenet':
        coconut_model = model.MobileNetV2(n_class=num_classes, input_size=256)
    else:
        raise NotImplementedError(
            'Not Implemented! Unknown model_arc: {}'.format(args.model_arc))

    coconut_model = nn.DataParallel(coconut_model)
    if args.cuda:
        coconut_model = coconut_model.cuda()
        # The cuDNN autotuner flag lives on torch.backends.cudnn. Setting
        # torch.backends.benchmark only created an unused attribute.
        torch.backends.cudnn.benchmark = True
        print("CUDA Enabled")
        gpu_count = torch.cuda.device_count()
        print('Total of %d GPU available' % (gpu_count))
        # DataParallel splits each batch across GPUs, so scale per-GPU sizes.
        args.train_batch_size = args.train_batch_size * gpu_count
        args.test_batch_size = args.test_batch_size * gpu_count
        print('args.train_batch_size: %d' % (args.train_batch_size))
        print('args.test_batch_size: %d' % (args.test_batch_size))

    params = sum(
        np.prod(p.size()) for p in coconut_model.parameters()
        if p.requires_grad)
    print('Total of %d parameters' % (params))

    # Build Training
    start_epoch = 0
    best_acc = 0
    milestones = [50, 150, 250]
    if args.train_optimizer == 'sgd':
        optimizer = optim.SGD(coconut_model.parameters(),
                              lr=args.lr,
                              momentum=0.9,
                              nesterov=True,
                              weight_decay=args.l2_reg)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=milestones,
                                                   gamma=0.1)
    elif args.train_optimizer == 'adam':
        optimizer = optim.Adam(coconut_model.parameters(), lr=args.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                                   milestones=milestones,
                                                   gamma=0.1)
    elif args.train_optimizer == 'adabound':
        # NOTE(review): AdaBound ignores args.lr and uses fixed lr/final_lr.
        optimizer = adabound.AdaBound(coconut_model.parameters(),
                                      lr=1e-3,
                                      final_lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=150,
                                                    gamma=0.1,
                                                    last_epoch=-1)
    else:
        # Previously optimizer/scheduler silently stayed None and the run
        # failed later with a confusing AttributeError.
        raise NotImplementedError(
            'Not Implemented! Unknown train_optimizer: {}'.format(
                args.train_optimizer))

    global_steps = 0
    if not args.start_from_begining:
        filename = args.model_checkpoint_path
        if args.load_gpu_model_on_cpu:
            checkpoint = torch.load(filename,
                                    map_location=lambda storage, loc: storage)
        else:
            checkpoint = torch.load(filename)

        coconut_model.load_state_dict(checkpoint['model_state_dict'])
        # The scheduler was built around this same optimizer object, and
        # load_state_dict updates that object in place, so the scheduler
        # needs no re-binding.
        optimizer.load_state_dict(checkpoint['model_optimizer'])
        best_acc = checkpoint['best_acc']
        train_history_dict = checkpoint['train_history_dict']
        # NOTE(review): the scheduler's own epoch counter is not restored, so
        # MultiStepLR/StepLR milestones presumably restart from 0 on resume.
        # Confirm against how train_ops steps the scheduler.
        start_epoch = checkpoint['epoch']
        global_steps = checkpoint['global_steps']
        print(filename + ' loaded!')

    data_loaders = load_datasets()
    train_ops(start_epoch=start_epoch,
              model=coconut_model,
              optimizer=optimizer,
              scheduler=scheduler,
              data_loaders=data_loaders,
              best_acc=best_acc,
              global_steps=global_steps)
Пример #18
0
def main(args=None):
    """Visualise RetinaNet detections on a COCO val2017 or CSV dataset.

    Loads the weights in ``--model`` into a fresh ResNet-``--depth``
    RetinaNet, runs it image by image over the validation set, and shows
    each image in an OpenCV window. Every detection scoring above 0.5 is
    drawn as a box with its class label. Each window waits for a key press.

    Args:
        args: Optional argument list for ``argparse``; ``None`` parses the
            process command line.

    Raises:
        ValueError: on missing required paths, an unknown dataset type or an
            unsupported depth.
    """
    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset',
                        help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument('--csv_classes',
                        help='Path to file containing class list (see readme)')
    parser.add_argument('--csv_val',
                        help='Path to file containing validation annotations '
                        '(optional, see readme)')
    # The model construction below reads parser.depth, but the option was
    # never declared, so every run failed with AttributeError.
    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=50)

    parser.add_argument('--model', help='Path to model (.pt) file.')

    parser = parser.parse_args(args)

    if parser.model is None:
        raise ValueError('Must provide --model.')

    if parser.dataset == 'coco':
        dataset_val = CocoDataset(parser.coco_path,
                                  set_name='val2017',
                                  transform=transforms.Compose(
                                      [Normalizer(), Resizer()]))
    elif parser.dataset == 'csv':
        # The annotations to visualise come from --csv_val; the old code read
        # the undeclared parser.csv_train and crashed.
        if parser.csv_val is None or parser.csv_classes is None:
            raise ValueError(
                'Must provide --csv_val and --csv_classes for csv dataset.')
        dataset_val = CSVDataset(train_file=parser.csv_val,
                                 class_list=parser.csv_classes,
                                 transform=transforms.Compose(
                                     [Normalizer(), Resizer()]))
    else:
        raise ValueError(
            'Dataset type not understood (must be csv or coco), exiting.')

    sampler_val = AspectRatioBasedSampler(dataset_val,
                                          batch_size=1,
                                          drop_last=False)
    dataloader_val = DataLoader(dataset_val,
                                num_workers=1,
                                collate_fn=collater,
                                batch_sampler=sampler_val)

    # Create the model
    if parser.depth == 18:
        retinanet = model.resnet18(num_classes=dataset_val.num_classes(),
                                   pretrained=True)
    elif parser.depth == 34:
        retinanet = model.resnet34(num_classes=dataset_val.num_classes(),
                                   pretrained=True)
    elif parser.depth == 50:
        retinanet = model.resnet50(num_classes=dataset_val.num_classes(),
                                   pretrained=True)
    elif parser.depth == 101:
        retinanet = model.resnet101(num_classes=dataset_val.num_classes(),
                                    pretrained=True)
    elif parser.depth == 152:
        retinanet = model.resnet152(num_classes=dataset_val.num_classes(),
                                    pretrained=True)
    else:
        raise ValueError(
            "Unsupported model depth, must be one of 18, 34, 50, 101, 152")
    # Load onto the CPU first: the freshly built model is still on the CPU,
    # and this also works for GPU-saved weights on a CPU-only host.
    retinanet.load_state_dict(torch.load(parser.model, map_location='cpu'))

    device = "cuda" if torch.cuda.is_available() else "cpu"

    retinanet = retinanet.to(device)

    retinanet.eval()

    unnormalize = UnNormalizer()

    def draw_caption(image, box, caption):
        """Write `caption` just above the top-left corner of `box`.

        The caption is drawn twice, thick black then thin white, so it stays
        readable on any background.
        """
        b = np.array(box).astype(int)
        cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN,
                    1, (0, 0, 0), 2)
        cv2.putText(image, caption, (b[0], b[1] - 10), cv2.FONT_HERSHEY_PLAIN,
                    1, (255, 255, 255), 1)

    for data in dataloader_val:

        with torch.no_grad():
            img = data['img'].to(device).float()
            st = time.time()
            scores, classification, transformed_anchors = retinanet(img)
            print('Elapsed time: {}'.format(time.time() - st))
            # np.where cannot consume a CUDA tensor; bring the scores to host.
            idxs = np.where(scores.cpu() > 0.5)
            img = np.array(255 * unnormalize(data['img'][0, :, :, :])).copy()

            img[img < 0] = 0
            img[img > 255] = 255

            # CHW -> HWC for OpenCV.
            img = np.transpose(img, (1, 2, 0))

            img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB)

            for j in range(idxs[0].shape[0]):
                bbox = transformed_anchors[idxs[0][j], :]
                x1 = int(bbox[0])
                y1 = int(bbox[1])
                x2 = int(bbox[2])
                y2 = int(bbox[3])
                label_name = dataset_val.labels[int(
                    classification[idxs[0][j]])]
                draw_caption(img, (x1, y1, x2, y2), label_name)

                cv2.rectangle(img, (x1, y1), (x2, y2),
                              color=(0, 0, 255),
                              thickness=2)
                print(label_name)

            cv2.imshow('img', img)
            cv2.waitKey(0)
Пример #19
0
def main(args=None):
    """
    In current implementation, if test csv is provided, we use that as validation set and combine the val and train csv's 
    as the csv for training.

    If train_all_labeled_data flag is use, then we combine all 3 (if test is provided) for training and use a prespecified learning rate step schedule.
    """

    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')
    parser.add_argument(
        '--csv_train',
        help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes',
                        help='Path to file containing class list (see readme)')
    parser.add_argument(
        '--csv_val',
        help=
        'Path to file containing validation annotations (optional, see readme)',
        default=None)
    parser.add_argument(
        '--csv_test',
        help=
        'Path to file containing test annotations (optional, if provided, train & val will be combined for training and test will be used for evaluation)',
        default=None)
    parser.add_argument('--lr', type=float, default=2e-5)
    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=101)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=25)
    parser.add_argument('--model_output_dir', type=str, default='models')
    parser.add_argument(
        '--train_all_labeled_data',
        help=
        'Combine train, val, and test into 1 training set. Will use prespecified learning rate scheduler steps',
        action='store_true')
    parser.add_argument('--resnet-backbone-normalization',
                        choices=['batch_norm', 'group_norm'],
                        type=str,
                        default='batch_norm')

    parser = parser.parse_args(args)

    print('Learning Rate: {}'.format(parser.lr))
    print("Normalization: ", parser.resnet_backbone_normalization)

    # Create folder - will raise error if folder exists
    assert (os.path.exists(parser.model_output_dir) == False)
    os.mkdir(parser.model_output_dir)

    if parser.csv_train is None:
        raise ValueError('Must provide --csv_train when training,')

    if parser.csv_classes is None:
        raise ValueError('Must provide --csv_classes when training,')

    if not parser.csv_val and parser.csv_test:
        raise ValueError(
            "Cannot specify test set without specifying validation set")

    if parser.train_all_labeled_data:
        csv_paths = [parser.csv_train, parser.csv_val, parser.csv_test]
        train_csv = []
        for path in csv_paths:
            if isinstance(path, str):
                train_csv.append(path)
        val_csv = None
    else:
        if parser.csv_train and parser.csv_val and parser.csv_test:
            train_csv = [parser.csv_train, parser.csv_val
                         ]  # Combine train and val sets for training
            val_csv = parser.csv_test
        else:
            train_csv = parser.csv_train
            val_csv = parser.csv_val

    print('loading train data')
    print(train_csv)
    dataset_train = CSVDataset(train_file=train_csv,
                               class_list=parser.csv_classes,
                               transform=transforms.Compose(
                                   [Normalizer(),
                                    Augmenter(),
                                    Resizer()]))
    print(dataset_train.__len__())

    if val_csv is None:
        dataset_val = None
        print('No validation annotations provided.')
    else:
        dataset_val = CSVDataset(train_file=val_csv,
                                 class_list=parser.csv_classes,
                                 transform=transforms.Compose(
                                     [Normalizer(), Resizer()]))

    print('putting data into loader')
    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=2,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=3,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    if dataset_val is not None:
        sampler_val = AspectRatioBasedSampler(dataset_val,
                                              batch_size=1,
                                              drop_last=False)
        dataloader_val = DataLoader(dataset_val,
                                    num_workers=3,
                                    collate_fn=collater,
                                    batch_sampler=sampler_val)

    # Create the model
    print('creating model')
    if parser.depth == 18:
        retinanet = model.resnet18(
            num_classes=dataset_train.num_classes(),
            pretrained=True,
            normalization=parser.resnet_backbone_normalization)
    elif parser.depth == 34:
        retinanet = model.resnet34(
            num_classes=dataset_train.num_classes(),
            pretrained=True,
            normalization=parser.resnet_backbone_normalization)
    elif parser.depth == 50:
        retinanet = model.resnet50(
            num_classes=dataset_train.num_classes(),
            pretrained=True,
            normalization=parser.resnet_backbone_normalization)
    elif parser.depth == 101:
        retinanet = model.resnet101(
            num_classes=dataset_train.num_classes(),
            pretrained=True,
            normalization=parser.resnet_backbone_normalization)
    elif parser.depth == 152:
        retinanet = model.resnet152(
            num_classes=dataset_train.num_classes(),
            pretrained=True,
            normalization=parser.resnet_backbone_normalization)
    else:
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    use_gpu = True

    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet = torch.nn.DataParallel(retinanet).cuda()

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=parser.lr)

    lr_factor = 0.3
    if not parser.train_all_labeled_data:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         patience=3,
                                                         factor=lr_factor,
                                                         verbose=True)
    else:
        # these milestones are for when using the lung masks - not for unmasked lung data
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[12, 16, 20,
                                   24], gamma=lr_factor)  # masked training
        #scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[14, 18, 22, 26], gamma=lr_factor)

    loss_hist = collections.deque(maxlen=500)

    retinanet.train()
    retinanet.module.freeze_bn()

    #initialize tensorboard
    writer = SummaryWriter(comment=parser.model_output_dir)

    # Augmentation
    seq = iaa.Sequential([
        iaa.Fliplr(0.5),
        iaa.Flipud(0.5),
        iaa.Affine(scale={
            "x": (1.0, 1.2),
            "y": (1.0, 1.2)
        },
                   rotate=(-20, 20),
                   shear=(-4, 4))
    ],
                         random_order=True)

    def augment(data, seq):
        for n, img in enumerate(data['img']):
            # imgaug needs dim in format (H, W, C)
            image = data['img'][n].permute(1, 2, 0).numpy()

            bbs_array = []
            for ann in data['annot'][n]:
                x1, y1, x2, y2, _ = ann
                bbs_array.append(BoundingBox(x1=x1, y1=y1, x2=x2, y2=y2))

            bbs = BoundingBoxesOnImage(bbs_array, shape=image.shape)
            image_aug, bbs_aug = seq(image=image, bounding_boxes=bbs)

            # save augmented image and chage dims to (C, H, W)
            data['img'][n] = torch.tensor(image_aug.copy()).permute(2, 0, 1)

            # save augmented annotations
            for i, bbox in enumerate(bbs_aug.bounding_boxes):
                x1, y1, x2, y2 = bbox.x1, bbox.y1, bbox.x2, bbox.y2
                obj_class = data['annot'][n][i][-1]
                data['annot'][n][i] = torch.tensor([x1, y1, x2, y2, obj_class])

        return data

    print('Num training images: {}'.format(len(dataset_train)))
    dir_training_images = os.path.join(os.getcwd(), writer.log_dir,
                                       'training_images')
    os.mkdir(dir_training_images)

    best_validation_loss = None
    best_validation_map = None

    for epoch_num in range(parser.epochs):

        writer.add_scalar('Train/LR', optimizer.param_groups[0]['lr'],
                          epoch_num)

        retinanet.train()
        retinanet.module.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()

                data = augment(data, seq)

                # save a few training images to see what augmentation looks like
                if iter_num % 100 == 0 and epoch_num == 0:
                    x1, y1, x2, y2, _ = data['annot'][0][0]

                    fig, ax = plt.subplots(1)
                    ax.imshow(data['img'][0][1])
                    rect = patches.Rectangle((x1, y1),
                                             x2 - x1,
                                             y2 - y1,
                                             linewidth=1,
                                             edgecolor='r',
                                             facecolor='none',
                                             alpha=1)
                    ax.add_patch(rect)
                    fig.savefig(
                        os.path.join(dir_training_images,
                                     '{}.png'.format(iter_num)))
                    plt.close()

                classification_loss, regression_loss = retinanet(
                    [data['img'].cuda().float(), data['annot']])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()

                loss = classification_loss + regression_loss

                if bool(loss == 0):
                    continue

                loss.backward()

                if parser.resnet_backbone_normalization == 'batch_norm':
                    torch.nn.utils.clip_grad_norm_(
                        parameters=retinanet.parameters(), max_norm=0.1)
                else:
                    torch.nn.utils.clip_grad_norm_(
                        parameters=retinanet.parameters(), max_norm=0.01
                    )  # Decrease norm to reduce risk of exploding gradients

                optimizer.step()

                loss_hist.append(float(loss))

                epoch_loss.append(float(loss))

                print(
                    'Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'
                    .format(epoch_num, iter_num, float(classification_loss),
                            float(regression_loss), np.mean(loss_hist)))

                del classification_loss
                del regression_loss
            except Exception as e:
                print(e)
                continue

        writer.add_scalar('Train/Loss', np.mean(epoch_loss), epoch_num)

        if not parser.train_all_labeled_data:
            print('Evaluating Validation Loss...')
            with torch.no_grad():
                retinanet.train()
                val_losses, val_class_losses, val_reg_losses = [], [], []
                for val_iter_num, val_data in enumerate(dataloader_val):
                    try:
                        val_classification_loss, val_regression_loss = retinanet(
                            [
                                val_data['img'].cuda().float(),
                                val_data['annot']
                            ])
                        val_losses.append(
                            float(val_classification_loss) +
                            float(val_regression_loss))
                        val_class_losses.append(float(val_classification_loss))
                        val_reg_losses.append(float(val_regression_loss))
                        del val_classification_loss, val_regression_loss
                    except Exception as e:
                        print(e)
                        continue
                print(
                    'VALIDATION Epoch: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Total loss: {:1.5f}'
                    .format(epoch_num, np.mean(val_class_losses),
                            np.mean(val_reg_losses), np.mean(val_losses)))

                # Save model with best validation loss
                if best_validation_loss is None:
                    best_validation_loss = np.mean(val_losses)
                if best_validation_loss >= np.mean(val_losses):
                    best_validation_loss = np.mean(val_losses)
                    torch.save(
                        retinanet.module,
                        parser.model_output_dir + '/best_result_valloss.pt')

                writer.add_scalar('Validation/Loss', np.mean(val_losses),
                                  epoch_num)

                # Calculate Validation mAP
                print('Evaluating validation mAP')
                mAP = csv_eval.evaluate(dataset_val, retinanet)
                print("Validation mAP: " + str(mAP[0][0]))
                if best_validation_map is None:
                    best_validation_map = mAP[0][0]
                elif best_validation_map < mAP[0][0]:
                    best_validation_map = mAP[0][0]
                    torch.save(
                        retinanet.module,
                        parser.model_output_dir + '/best_result_valmAP.pt')

                writer.add_scalar('Validation/mAP', mAP[0][0], epoch_num)

        if not parser.train_all_labeled_data:
            scheduler.step(np.mean(val_losses))
        else:
            scheduler.step()

        torch.save(
            retinanet.module,
            parser.model_output_dir + '/retinanet_{}.pt'.format(epoch_num))

    retinanet.eval()

    torch.save(retinanet, parser.model_output_dir + '/model_final.pt')
Пример #20
0
def main(args=None):
    """Train a RetinaNet detector on a COCO or CSV dataset.

    Arguments:
        args {list[str] or None} -- command-line tokens to parse; when None,
            argparse falls back to ``sys.argv``.

    Raises:
        ValueError -- if the dataset type is unknown, a path option required
            by the chosen dataset is missing, or ``--depth`` is not one of
            18, 34, 50, 101, 152.

    Side effects:
        Moves the model to CUDA (a GPU is required), saves the bare model
        after every epoch as ``<dataset>_retinanet_dilation_<epoch>.pt`` and
        the DataParallel-wrapped model as ``model_final_dilation.pt``, both
        in the current working directory.
    """
    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset',
                        help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument(
        '--csv_train',
        help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes',
                        help='Path to file containing class list (see readme)')
    parser.add_argument(
        '--csv_val',
        help=
        'Path to file containing validation annotations (optional, see readme)'
    )

    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=50)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=100)

    parser = parser.parse_args(args)

    # Create the data loaders
    if parser.dataset == 'coco':

        if parser.coco_path is None:
            raise ValueError('Must provide --coco_path when training on COCO,')

        dataset_train = CocoDataset(parser.coco_path,
                                    set_name='train2017',
                                    transform=transforms.Compose(
                                        [Normalizer(),
                                         Augmenter(),
                                         Resizer()]))
        dataset_val = CocoDataset(parser.coco_path,
                                  set_name='val2017',
                                  transform=transforms.Compose(
                                      [Normalizer(), Resizer()]))

    elif parser.dataset == 'csv':

        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train when training on CSV,')

        if parser.csv_classes is None:
            raise ValueError(
                'Must provide --csv_classes when training on CSV,')

        dataset_train = CSVDataset(train_file=parser.csv_train,
                                   class_list=parser.csv_classes,
                                   transform=transforms.Compose(
                                       [Normalizer(),
                                        Augmenter(),
                                        Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val,
                                     class_list=parser.csv_classes,
                                     transform=transforms.Compose(
                                         [Normalizer(),
                                          Resizer()]))

    else:
        raise ValueError(
            'Dataset type not understood (must be csv or coco), exiting.')

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=2,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=3,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    # NOTE(review): dataloader_val is built but never consumed below; the
    # evaluators receive dataset_val directly.
    if dataset_val is not None:
        sampler_val = AspectRatioBasedSampler(dataset_val,
                                              batch_size=1,
                                              drop_last=False)
        dataloader_val = DataLoader(dataset_val,
                                    num_workers=3,
                                    collate_fn=collater,
                                    batch_sampler=sampler_val)

    # Create the model
    if parser.depth == 18:
        retinanet = model.resnet18(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 34:
        retinanet = model.resnet34(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 50:
        retinanet = model.resnet50(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 101:
        retinanet = model.resnet101(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    elif parser.depth == 152:
        retinanet = model.resnet152(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    else:
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    # Module.cuda() recurses into child modules, so a single call on the
    # wrapper also moves the wrapped RetinaNet. From here on
    # ``retinanet.module`` is the bare model.
    retinanet = torch.nn.DataParallel(retinanet).cuda()

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)

    # Stepped once per epoch with the mean training loss (see end of loop).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     verbose=True)

    # Running window over the last 500 iteration losses, for logging only.
    loss_hist = collections.deque(maxlen=500)

    print('Num training images: {}'.format(len(dataset_train)))

    for epoch_num in range(parser.epochs):

        retinanet.train()
        # freeze_bn is a RetinaNet method, so it must be called on the
        # unwrapped module.
        retinanet.module.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()

                classification_loss, regression_loss = retinanet(
                    [data['img'].cuda().float(), data['annot']])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()

                loss = classification_loss + regression_loss

                # A zero loss contributes no gradient; skip the update.
                if bool(loss == 0):
                    continue

                loss.backward()

                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

                optimizer.step()

                loss_hist.append(float(loss))

                epoch_loss.append(float(loss))

                print(
                    'Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'
                    .format(epoch_num, iter_num, float(classification_loss),
                            float(regression_loss), np.mean(loss_hist)))

                del classification_loss
                del regression_loss
            except Exception as e:
                # Deliberate best-effort: a failing batch is reported and
                # skipped rather than aborting the whole run.
                print(e)
                continue

        if parser.dataset == 'coco':

            print('Evaluating dataset')

            coco_eval.evaluate_coco(dataset_val, retinanet)

        elif parser.dataset == 'csv' and parser.csv_val is not None:

            print('Evaluating dataset')

            csv_eval.evaluate(dataset_val, retinanet)

        scheduler.step(np.mean(epoch_loss))

        torch.save(
            retinanet.module,
            '{}_retinanet_dilation_{}.pt'.format(parser.dataset, epoch_num))

    retinanet.eval()

    # The file name has no placeholder. The old ``.format(epoch_num)`` also
    # raised NameError when --epochs 0 left epoch_num unbound.
    torch.save(retinanet, 'model_final_dilation.pt')
Пример #21
0
def _test_accuracy(net, test_loader):
    """Return the mean element-wise accuracy of ``net`` over ``test_loader``.

    Predictions are ``sigmoid(logits) > 0.5`` compared element-wise with the
    labels. Forward passes run under ``torch.no_grad()``. ``net`` is left in
    eval mode; the caller must switch it back to train mode.
    """
    net.eval()
    hits = []
    for test_img, test_label in test_loader:
        test_img = to_var(test_img, requires_grad=False)
        test_label = to_var(test_label, requires_grad=False)

        with torch.no_grad():
            output = net(test_img)
        predicted = torch.sigmoid(output) > 0.5
        hits.append((predicted.float() == test_label.float()).float())

    # .item() yields a plain float, so it can go into a NumPy array even when
    # the tensor lives on the GPU.
    return torch.cat(hits, dim=0).mean().item()


def train_net(noise_fraction,
              lr=1e-3,
              momentum=0.9,
              batch_size=128,
              dir_img='ISIC_2019_Training_Input/',
              save_cp=True,
              dir_checkpoint='checkpoints/ISIC_2019_Training_Input/',
              epochs=10):
    """Train ``net`` with per-element loss weights meta-learned each step.

    Every step does the following:
      * copy ``net`` into a fresh ``meta_net``;
      * take a virtual parameter update on a loss weighted by the all-zero
        tensor ``eps``;
      * evaluate ``meta_net`` on one fixed validation batch;
      * turn the clamped, normalised negative gradient w.r.t. ``eps`` into
        weights ``w`` for the real loss on ``net``.

    Arguments:
        noise_fraction -- forwarded to ``BasicDataset`` for every split.

    Keyword Arguments:
        lr {float} -- passed to ``build_model`` and used for the virtual meta
            update (default: {1e-3})
        momentum {float} -- currently unused by this function (default: {0.9})
        batch_size {int} -- train/test batch size (default: {128})
        dir_img {str} -- image directory passed to ``BasicDataset``
        save_cp {bool} -- save ``net.state_dict()`` after every epoch
            (default: {True})
        dir_checkpoint {str} -- directory for ``CP_epoch<k>.pth``; created,
            including parents, if missing
        epochs {int} -- number of epochs (default: {10})

    Returns:
        float -- mean test accuracy over the logged evaluations from the
            6th-to-last up to, but excluding, the most recent one (NumPy
            slice ``[-6:-1]``). NaN if no evaluation ran, e.g. ``epochs == 0``
            or fewer training samples than ``batch_size``.
    """
    train = BasicDataset(dir_img, noise_fraction, mode='train')
    test = BasicDataset(dir_img, noise_fraction, mode='test')
    val = BasicDataset(dir_img, noise_fraction, mode='val')
    data_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
    test_loader = DataLoader(test, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)
    val_loader = DataLoader(val, batch_size=5, shuffle=False, num_workers=8, pin_memory=True)

    # The validation set is small, so one fixed batch is reused as the
    # meta-objective for every step instead of iterating it.
    val_data, val_labels = next(iter(val_loader))
    val_data = to_var(val_data, requires_grad=False)
    val_labels = to_var(val_labels, requires_grad=False).float()

    data = iter(data_loader)

    net, opt = build_model(lr)
    plot_step = 100
    accuracy_log = []

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Checkpoints:     {save_cp}
        Noise fraction:  {noise_fraction}
        Image dir:       {dir_img}
    ''')

    steps_per_epoch = len(train) // batch_size
    for epoch in range(epochs):
        net.train()
        for step in tqdm(range(steps_per_epoch)):
            # Line 2: fetch exactly one batch, restarting the loader when it
            # is exhausted.
            try:
                image, labels = next(data)
            except StopIteration:
                data = iter(data_loader)
                image, labels = next(data)

            # Initialise a throwaway copy of the network for the meta step.
            meta_net = model.resnet101(pretrained=False, num_classes=9)
            meta_net.load_state_dict(net.state_dict())

            if torch.cuda.is_available():
                meta_net.cuda()

            image = to_var(image, requires_grad=False)
            labels = to_var(labels, requires_grad=False).float()

            # Lines 4 - 5: initial forward pass computing the eps-weighted loss.
            y_f_hat = meta_net(image)
            cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduction='none')
            # NOTE(review): eps must require grad for the autograd.grad call
            # below; this assumes to_var defaults to requires_grad=True.
            eps = to_var(torch.zeros(cost.size()))
            l_f_meta = torch.sum(cost * eps)

            meta_net.zero_grad()

            # Line 6: virtual parameter update. create_graph=True keeps the
            # dependency on eps so it can be differentiated below.
            grads = torch.autograd.grad(l_f_meta, (meta_net.params()), create_graph=True, allow_unused=True)
            meta_net.update_params(lr, source_params=grads)

            # Lines 8 - 10: validation loss of the updated copy and its
            # gradient with respect to eps.
            y_g_hat = meta_net(val_data)
            l_g_meta = F.binary_cross_entropy_with_logits(y_g_hat, val_labels)
            grad_eps = torch.autograd.grad(l_g_meta, eps)[0]

            # Line 11: keep only elements whose up-weighting lowers the
            # validation loss, then normalise (unless all weights are zero).
            w_tilde = torch.clamp(-grad_eps, min=0)
            norm_c = torch.sum(w_tilde)
            w = w_tilde / norm_c if norm_c != 0 else w_tilde

            # Lines 12 - 14: weighted loss on the real network and update.
            y_f_hat = net(image)
            cost = F.binary_cross_entropy_with_logits(y_f_hat, labels, reduction='none')
            l_f = torch.sum(cost * w)

            opt.zero_grad()
            l_f.backward()
            opt.step()

            if step % plot_step == 0:
                accuracy = _test_accuracy(net, test_loader)
                accuracy_log.append(np.array([step, accuracy])[None])
                # Evaluation switched to eval mode; resume training mode for
                # the remaining steps of this epoch.
                net.train()

        if save_cp:
            # makedirs also creates missing parents such as "checkpoints/".
            if not os.path.isdir(dir_checkpoint):
                os.makedirs(dir_checkpoint, exist_ok=True)
                logging.info('Created checkpoint directory')
            torch.save(net.state_dict(),
                       os.path.join(dir_checkpoint, f'CP_epoch{epoch + 1}.pth'))
            logging.info(f'Checkpoint {epoch + 1} saved !')

    if not accuracy_log:
        return float('nan')
    acc_log = np.concatenate(accuracy_log, axis=0)
    # NOTE(review): [-6:-1] excludes the latest evaluation; confirm that this
    # is intended rather than [-5:].
    return np.mean(acc_log[-6:-1, 1])
Пример #22
0
                       download=False,
                       transform=transforms_test)
train_loader = DataLoader(dataset_train,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=args.num_worker)
test_loader = DataLoader(dataset_test,
                         batch_size=args.batch_size_test,
                         shuffle=False,
                         num_workers=args.num_worker)
# Human-readable class names, presumably in label-index order (these are the
# CIFAR-10 categories).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
           'ship', 'truck')

print('==> Making model..')

net = model.resnet101()
net = net.to(device)
# Counts trainable parameters only.
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('The number of parameters of model is', num_params)

if args.resume is not None:
    # map_location remaps the stored tensors onto this run's device, so a
    # checkpoint written on a GPU can also be restored on a CPU-only host.
    checkpoint = torch.load('./save_model/' + args.resume, map_location=device)
    net.load_state_dict(checkpoint['net'])

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),
                      lr=0.1,
                      momentum=0.9,
                      weight_decay=1e-4)

# NOTE(review): despite the name these values look like iteration counts
# rather than epochs; confirm against the code that consumes decay_epoch.
decay_epoch = [32000, 48000]
Пример #23
0
def main(args=None):
    """Train a RetinaNet detector on a COCO, YCB or CSV dataset.

    Arguments:
        args {list[str] or None} -- command-line tokens to parse; when None,
            argparse falls back to ``sys.argv``.

    Raises:
        ValueError -- if the dataset type is unknown, a path option required
            by the chosen dataset is missing, or ``--depth`` is not one of
            18, 34, 50, 101, 152.

    Side effects:
        Saves ``retinanet.state_dict()`` to
        ``<dataset>_retinanet_best_mean_ap_<epoch>.pt`` whenever validation
        mAP improves, and to ``retinanet_model_final.pt`` at the end, in the
        current working directory. With ``--distributed`` the saved keys
        carry DataParallel's ``module.`` prefix.
    """
    parser = argparse.ArgumentParser(
        description="Simple training script for training a RetinaNet network.")

    parser.add_argument(
        "--dataset", help="Dataset type, must be one of csv or coco or ycb.")
    parser.add_argument("--path", help="Path to dataset directory")
    parser.add_argument(
        "--csv_train",
        help="Path to file containing training annotations (see readme)")
    parser.add_argument("--csv_classes",
                        help="Path to file containing class list (see readme)")
    parser.add_argument("--csv_val",
                        help="Path to file containing validation annotations "
                        "(optional, see readme)")

    parser.add_argument(
        "--depth",
        help="Resnet depth, must be one of 18, 34, 50, 101, 152",
        type=int,
        default=50)
    parser.add_argument("--epochs",
                        help="Number of epochs",
                        type=int,
                        default=100)
    parser.add_argument("--evaluate_every", default=20, type=int)
    parser.add_argument("--print_every", default=20, type=int)
    parser.add_argument('--distributed',
                        action="store_true",
                        help='Run model in distributed mode with DataParallel')

    parser = parser.parse_args(args)

    # Create the data loaders
    if parser.dataset == "coco":

        if parser.path is None:
            raise ValueError(
                "Must provide --path when training on non-CSV datasets")

        dataset_train = CocoDataset(parser.path,
                                    ann_file="instances_train2014.json",
                                    set_name="train2014",
                                    transform=transforms.Compose([
                                        Normalizer(),
                                        Augmenter(),
                                        Resizer(min_side=512, max_side=512)
                                    ]))
        dataset_val = CocoDataset(parser.path,
                                  ann_file="instances_val2014.cars.json",
                                  set_name="val2014",
                                  transform=transforms.Compose(
                                      [Normalizer(), Resizer()]))

    elif parser.dataset == "ycb":

        # YCB is also a directory-based dataset and needs --path.
        if parser.path is None:
            raise ValueError(
                "Must provide --path when training on non-CSV datasets")

        dataset_train = YCBDataset(parser.path,
                                   "image_sets/train.txt",
                                   transform=transforms.Compose([
                                       Normalizer(),
                                       Augmenter(),
                                       Resizer(min_side=512, max_side=512)
                                   ]),
                                   train=True)
        dataset_val = YCBDataset(parser.path,
                                 "image_sets/val.txt",
                                 transform=transforms.Compose(
                                     [Normalizer(), Resizer()]),
                                 train=False)

    elif parser.dataset == "csv":

        if parser.csv_train is None:
            raise ValueError("Must provide --csv_train when training on CSV,")

        if parser.csv_classes is None:
            raise ValueError(
                "Must provide --csv_classes when training on CSV,")

        dataset_train = CSVDataset(train_file=parser.csv_train,
                                   class_list=parser.csv_classes,
                                   transform=transforms.Compose(
                                       [Normalizer(),
                                        Augmenter(),
                                        Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print("No validation annotations provided.")
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val,
                                     class_list=parser.csv_classes,
                                     transform=transforms.Compose(
                                         [Normalizer(),
                                          Resizer()]))

    else:
        raise ValueError(
            "Dataset type not understood (must be csv, coco or ycb), exiting.")

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=12,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=8,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    # NOTE(review): dataloader_val is built but never consumed below; the
    # evaluators receive dataset_val directly.
    if dataset_val is not None:
        sampler_val = AspectRatioBasedSampler(dataset_val,
                                              batch_size=1,
                                              drop_last=False)
        dataloader_val = DataLoader(dataset_val,
                                    num_workers=4,
                                    collate_fn=collater,
                                    batch_sampler=sampler_val)

    # Create the model
    if parser.depth == 18:
        retinanet = model.resnet18(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 34:
        retinanet = model.resnet34(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 50:
        retinanet = model.resnet50(num_classes=dataset_train.num_classes(),
                                   pretrained=True)
    elif parser.depth == 101:
        retinanet = model.resnet101(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    elif parser.depth == 152:
        retinanet = model.resnet152(num_classes=dataset_train.num_classes(),
                                    pretrained=True)
    else:
        raise ValueError(
            "Unsupported model depth, must be one of 18, 34, 50, 101, 152")

    print("CUDA available: {}".format(torch.cuda.is_available()))
    device = "cuda" if torch.cuda.is_available() else "cpu"

    retinanet = retinanet.to(device)
    # Keep a handle on the bare model. DataParallel does not forward custom
    # methods such as freeze_bn(), so calling it on the wrapper would raise
    # AttributeError under --distributed.
    base_model = retinanet

    if parser.distributed:
        retinanet = torch.nn.DataParallel(retinanet)

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)

    # Stepped once per epoch with the mean training loss (see end of loop).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     verbose=True)

    # Running window over the last 500 iteration losses, for logging only.
    loss_hist = collections.deque(maxlen=500)

    print("Num training images: {}".format(len(dataset_train)))

    best_mean_avg_prec = 0.0

    for epoch_num in range(parser.epochs):

        retinanet.train()
        base_model.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()

                classification_loss, regression_loss = retinanet(
                    [data["img"].to(device).float(), data["annot"]])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()

                loss = classification_loss + regression_loss

                # A zero loss contributes no gradient; skip the update.
                if bool(loss == 0):
                    continue

                loss.backward()

                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

                optimizer.step()

                loss_value = loss.item()
                loss_hist.append(loss_value)
                epoch_loss.append(loss_value)

                # Log every `print_every` iterations. The old
                # `print_every % iter_num` was inverted and divided by zero at
                # iteration 0.
                if iter_num % parser.print_every == 0:
                    print("Epoch: {} | Iteration: {}/{} | "
                          "Classification loss: {:1.5f} | "
                          "Regression loss: {:1.5f} | "
                          "Running loss: {:1.5f}".format(
                              epoch_num, iter_num, len(dataloader_train),
                              float(classification_loss),
                              float(regression_loss), np.mean(loss_hist)))

                del classification_loss
                del regression_loss
            except Exception as e:
                # Deliberate best-effort: a failing batch is reported and
                # skipped rather than aborting the whole run.
                print(e)
                continue

        is_eval_epoch = ((epoch_num + 1) % parser.evaluate_every == 0
                         or epoch_num + 1 == parser.epochs)

        # A CSV run without --csv_val has nothing to evaluate against.
        if is_eval_epoch and dataset_val is not None:

            print("Evaluating dataset")
            if parser.dataset == "coco":
                # NOTE(review): assumes evaluate_coco returns a numeric mAP;
                # a None return would break the comparison below.
                mAP = coco_eval.evaluate_coco(dataset_val, retinanet)
            else:
                AP = eval.evaluate(dataset_val, retinanet)
                mAP = np.asarray([x[0] for x in AP.values()]).mean()
                print("Val set mAP: ", mAP)

            if mAP > best_mean_avg_prec:
                best_mean_avg_prec = mAP
                torch.save(
                    retinanet.state_dict(),
                    "{}_retinanet_best_mean_ap_{}.pt".format(
                        parser.dataset, epoch_num))

        scheduler.step(np.mean(epoch_loss))

    retinanet.eval()

    torch.save(retinanet.state_dict(), "retinanet_model_final.pt")
Пример #24
0
def main(config):
    """Train a siamese RetinaNet detector with optional per-epoch validation.

    Arguments:
        config -- configuration object; this function reads model_date,
            save_model, csv_train, csv_classes, csv_val, data_dir,
            batch_size, depth, use_gpu and epochs from it.

    Raises:
        ValueError -- if csv_train or csv_classes is missing, or if depth is
            not one of 18, 34, 50, 101, 152.
    """
    # set seed for reproducibility
    np.random.seed(0)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # create folder for model checkpoints; os.makedirs raises if the folder
    # already exists (presumably intentional: a run never overwrites another
    # run that used the same model_date)
    newpath = './models/' + config.model_date
    if config.save_model:
        os.makedirs(newpath)

    # Create the data loaders
    if config.csv_train is None:
        raise ValueError('Must provide --csv_train when training on csv,')

    if config.csv_classes is None:
        raise ValueError('Must provide --csv_classes when training on csv,')

    train_dataset = datasets.ImageFolder(os.path.join(config.data_dir,
                                                      'train'))
    dataset_train = GetDataset(train_file=config.csv_train,
                               class_list=config.csv_classes,
                               transform=transforms.Compose(
                                   [Augmenter(), Resizer()]),
                               dataset=train_dataset,
                               seed=0)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=1,
                                  collate_fn=collater)

    if config.csv_val is None:
        dataset_val = None
        print('No validation annotations provided.')
    else:
        valid_dataset = datasets.ImageFolder(
            os.path.join(config.data_dir, 'valid'))
        dataset_val = GetDataset(train_file=config.csv_val,
                                 class_list=config.csv_classes,
                                 transform=transforms.Compose([Resizer()]),
                                 dataset=valid_dataset,
                                 seed=0)

    # Create the model: model.resnet{depth} for one of the supported depths
    if config.depth not in (18, 34, 50, 101, 152):
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')
    build_model = getattr(model, 'resnet{}'.format(config.depth))
    retinanet = build_model(num_classes=dataset_train.num_classes(),
                            pretrained=True)

    if config.use_gpu:
        retinanet = retinanet.cuda()

    # NOTE: the DataParallel wrapper and every input batch below are moved to
    # CUDA unconditionally, so a GPU is required even when use_gpu is False.
    retinanet = torch.nn.DataParallel(retinanet).cuda()

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     verbose=True)

    retinanet.train()
    retinanet.module.freeze_bn()

    print('Num training images: {}'.format(len(dataset_train)))

    # Best validation score seen so far (the `correct` value returned by
    # eval_new.evaluate) and the number of epochs since it last improved.
    best_valid_map = 0
    counter = 0
    batch_size = config.batch_size

    for epoch_num in range(config.epochs):
        print('\nEpoch: {}/{}'.format(epoch_num + 1, config.epochs))
        retinanet.train()
        retinanet.module.freeze_bn()

        epoch_loss = []

        train_batch_time = AverageMeter()
        train_losses = AverageMeter()
        tic = time.time()
        with tqdm(total=len(dataset_train)) as pbar:
            for data in dataloader_train:
                optimizer.zero_grad()
                # The model returns three losses for an (image, annotations,
                # paired image) input.
                siamese_loss, classification_loss, regression_loss = retinanet(
                    [
                        data['img'].cuda().float(), data['annot'],
                        data['pair'].cuda().float()
                    ])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()
                loss = classification_loss + regression_loss + siamese_loss

                if bool(loss == 0):
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)
                optimizer.step()
                epoch_loss.append(float(loss))

                toc = time.time()
                train_losses.update(float(loss), batch_size)
                train_batch_time.update(toc - tic)
                tic = time.time()

                pbar.set_description(("{:.1f}s - loss: {:.3f}".format(
                    train_batch_time.val,
                    train_losses.val,
                )))
                pbar.update(batch_size)

                del classification_loss
                del regression_loss
                del siamese_loss

        # Validation state starts as "nothing measured" every epoch so that a
        # run without --csv_val does not reach undefined names below.
        is_best = False
        mAP = None
        correct = None
        if dataset_val is not None:
            print('Evaluating dataset')
            mAP, correct = eval_new.evaluate(dataset_val, retinanet)

            is_best = correct > best_valid_map
            best_valid_map = max(correct, best_valid_map)
            if is_best:
                counter = 0
            else:
                counter += 1
                # early stopping: more than 3 epochs without improvement
                if counter > 3:
                    print("[!] No improvement in a while, stopping training.")
                    break

        scheduler.step(np.mean(epoch_loss))
        if is_best and config.save_model:
            torch.save(
                retinanet.state_dict(),
                './models/{}/best_retinanet.pt'.format(config.model_date))
        if config.save_model:
            torch.save(
                retinanet.state_dict(),
                './models/{}/{}_retinanet_{}.pt'.format(
                    config.model_date, config.depth, epoch_num))

        if mAP is None:
            print("train loss: {:.3f}".format(train_losses.avg))
        else:
            msg = "train loss: {:.3f} - val map: {:.3f} - val acc: {:.3f}%"
            print(
                msg.format(train_losses.avg, mAP[0][0],
                           (100. * correct) / len(dataset_val)))
def main(args=None):
    """Train a RetinaNet detector on a COCO or CSV dataset, resuming from a checkpoint.

    Keyword Arguments:
        args {list or None} -- command-line arguments to parse; None lets
            argparse read the process arguments.

    If the file given by --resume_from (default ./checkpoint/retina_fpn_1)
    exists, the whole pickled model stored there replaces the freshly built
    network before training starts. After every epoch a dict with keys
    'net', 'loss' and 'epoch' is written to
    ./checkpoint/saved_with_epochs/retina_fpn_<epoch>.

    Raises:
        ValueError -- on a missing required path, an unknown dataset type or
            an unsupported ResNet depth.
    """
    parser     = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset',default="csv", help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--coco_path',default="/home/mayank-s/PycharmProjects/Datasets/coco",help='Path to COCO directory')
    parser.add_argument('--csv_train',default="berkely_ready_to_train_for_retinanet_pytorch.csv", help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes',default="berkely_class.csv", help='Path to file containing class list (see readme)')
    parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')

    parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=50)
    parser.add_argument('--epochs', help='Number of epochs', type=int, default=200)
    parser.add_argument('--resume_from', default="./checkpoint/retina_fpn_1", help='Path of a whole pickled model to resume from (skipped if missing)')
    parser = parser.parse_args(args)

    # Create the data loaders
    if parser.dataset == 'coco':

        if parser.coco_path is None:
            raise ValueError('Must provide --coco_path when training on COCO,')

        dataset_train = CocoDataset(parser.coco_path, set_name='train2014', transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))
        dataset_val = CocoDataset(parser.coco_path, set_name='val2014', transform=transforms.Compose([Normalizer(), Resizer()]))

    elif parser.dataset == 'csv':

        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train when training on CSV,')

        if parser.csv_classes is None:
            raise ValueError('Must provide --csv_classes when training on CSV,')

        dataset_train = CSVDataset(train_file=parser.csv_train, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Augmenter(), Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val, class_list=parser.csv_classes, transform=transforms.Compose([Normalizer(), Resizer()]))

    else:
        raise ValueError('Dataset type not understood (must be csv or coco), exiting.')

    sampler = AspectRatioBasedSampler(dataset_train, batch_size=4, drop_last=False)
    dataloader_train = DataLoader(dataset_train, num_workers=0, collate_fn=collater, batch_sampler=sampler)

    if dataset_val is not None:
        # NOTE(review): dataloader_val is not used by the active training code below.
        sampler_val = AspectRatioBasedSampler(dataset_val, batch_size=1, drop_last=False)
        dataloader_val = DataLoader(dataset_val, num_workers=3, collate_fn=collater, batch_sampler=sampler_val)

    # Create the model: model.resnet{depth} for one of the supported depths
    if parser.depth not in (18, 34, 50, 101, 152):
        raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')
    build_model = getattr(model, 'resnet{}'.format(parser.depth))
    retinanet = build_model(num_classes=dataset_train.num_classes(), pretrained=True)

    use_cuda = torch.cuda.is_available()

    # Resume BEFORE the optimizer is built: optim.Adam captures the parameter
    # tensors of the model it is given, so swapping the model afterwards would
    # leave optimizer.step() updating parameters that are no longer trained.
    if os.path.isfile(parser.resume_from):
        print('==> Resuming from checkpoint {}'.format(parser.resume_from))
        retinanet = torch.load(parser.resume_from, map_location=None if use_cuda else 'cpu')
    else:
        print('No checkpoint found at {}, training the freshly built model.'.format(parser.resume_from))

    # The checkpoint may hold a DataParallel wrapper; unwrap so it is wrapped exactly once.
    if isinstance(retinanet, torch.nn.DataParallel):
        retinanet = retinanet.module

    if use_cuda:
        retinanet = retinanet.cuda()

    # DataParallel also works without CUDA (it then just calls the wrapped
    # module), so always wrap: the code below relies on retinanet.module.
    retinanet = torch.nn.DataParallel(retinanet)
    if use_cuda:
        retinanet = retinanet.cuda()

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-5)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)

    # running window of the last 500 iteration losses
    loss_hist = collections.deque(maxlen=500)

    retinanet.train()
    retinanet.module.freeze_bn()

    print('Num training images: {}'.format(len(dataset_train)))

    checkpoint_dir = './checkpoint/saved_with_epochs'

    for epoch_num in range(parser.epochs):

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            try:
                optimizer.zero_grad()
                images = data['img'].float()
                if use_cuda:
                    images = images.cuda()
                classification_loss, regression_loss = retinanet([images, data['annot']])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()

                loss = classification_loss + regression_loss

                if bool(loss == 0):
                    continue

                loss.backward()

                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

                optimizer.step()

                loss_hist.append(float(loss))

                epoch_loss.append(float(loss))

                print('Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'.format(epoch_num, iter_num, float(classification_loss), float(regression_loss), np.mean(loss_hist)))

                del classification_loss
                del regression_loss
            except Exception as e:
                # best-effort: a failing batch is reported and skipped
                print(e)
                continue

        print('Saving..')
        state = {
            'net': retinanet.module.state_dict(),
            'loss': loss_hist,
            'epoch': epoch_num,
        }
        # makedirs also creates ./checkpoint if it is missing
        os.makedirs(checkpoint_dir, exist_ok=True)
        name = os.path.join(checkpoint_dir, "retina_fpn_" + str(epoch_num))
        torch.save(state, name)
        '''if parser.dataset == 'coco':
Пример #26
0
def main(args=None):
    """Train a RetinaNet detector on a COCO, CSV or PASCAL-VOC dataset.

    Keyword Arguments:
        args {list or None} -- command-line arguments to parse; None lets
            argparse read the process arguments.

    Every 50 iterations the losses are printed and appended to
    ./logs/log.txt. Whenever the value mAP[0][0] returned by the evaluator
    improves, the whole model is saved to ./logs/<epoch>_scale15_<map>.pt and
    copied to ./best_models/model.pt. The final model is saved to
    ./logs/model_final.pt.

    Raises:
        ValueError -- on a missing required path, an unknown dataset type or
            an unsupported ResNet depth.
    """
    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset',
                        default="csv",
                        help='Dataset type, must be one of csv or coco.')
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument(
        '--csv_train',
        default="./data/train_only.csv",
        help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes',
                        default="./data/classes.csv",
                        help='Path to file containing class list (see readme)')
    parser.add_argument(
        '--csv_val',
        default="./data/train_only.csv",
        help=
        'Path to file containing validation annotations (optional, see readme)'
    )
    parser.add_argument('--voc_train',
                        default="./data/voc_train",
                        help='Path to containing images and annAnnotations')
    parser.add_argument('--voc_val',
                        default="./data/bov_train",
                        help='Path to containing images and annAnnotations')
    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=101)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=40)

    parser = parser.parse_args(args)
    # Create the data loaders
    if parser.dataset == 'coco':

        if parser.coco_path is None:
            raise ValueError('Must provide --coco_path when training on COCO,')

        dataset_train = CocoDataset(parser.coco_path,
                                    set_name='train2017',
                                    transform=transforms.Compose(
                                        [Normalizer(),
                                         Augmenter(),
                                         Resizer()]))
        dataset_val = CocoDataset(parser.coco_path,
                                  set_name='val2017',
                                  transform=transforms.Compose(
                                      [Normalizer(), Resizer()]))

    elif parser.dataset == 'csv':

        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train when training on CSV,')

        if parser.csv_classes is None:
            raise ValueError(
                'Must provide --csv_classes when training on CSV,')

        dataset_train = CSVDataset(train_file=parser.csv_train,
                                   class_list=parser.csv_classes,
                                   transform=transforms.Compose(
                                       [Normalizer(),
                                        Augmenter(),
                                        Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val,
                                     class_list=parser.csv_classes,
                                     transform=transforms.Compose(
                                         [Normalizer(),
                                          Resizer()]))
    elif parser.dataset == 'voc':
        if parser.voc_train is None:
            raise ValueError(
                'Must provide --voc_train when training on PASCAL VOC,')
        # os.path.join works whether or not the directory argument ends in a
        # slash (plain concatenation broke the default "./data/voc_train").
        # NOTE(review): class_list is not defined in this function; it must
        # be provided at module level — confirm.
        dataset_train = XML_VOCDataset(
            img_path=os.path.join(parser.voc_train, 'JPEGImages/'),
            xml_path=os.path.join(parser.voc_train, 'Annotations/'),
            class_list=class_list,
            transform=transforms.Compose(
                [Normalizer(), Augmenter(),
                 ResizerMultiScale()]))

        if parser.voc_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = XML_VOCDataset(
                img_path=os.path.join(parser.voc_val, 'JPEGImages/'),
                xml_path=os.path.join(parser.voc_val, 'Annotations/'),
                class_list=class_list,
                transform=transforms.Compose([Normalizer(),
                                              Resizer()]))

    else:
        raise ValueError(
            'Dataset type not understood (must be csv, coco or voc), exiting.')

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=1,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=2,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    if dataset_val is not None:
        sampler_val = AspectRatioBasedSampler(dataset_val,
                                              batch_size=1,
                                              drop_last=False)
        dataloader_val = DataLoader(dataset_val,
                                    num_workers=2,
                                    collate_fn=collater,
                                    batch_sampler=sampler_val)

    # Create the model: model.resnet{depth} for one of the supported depths
    if parser.depth not in (18, 34, 50, 101, 152):
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')
    build_model = getattr(model, 'resnet{}'.format(parser.depth))
    retinanet = build_model(num_classes=dataset_train.num_classes(),
                            pretrained=True)

    use_gpu = True

    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet = torch.nn.DataParallel(retinanet).cuda()

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-4)

    # mode="max": the scheduler is stepped with the best mAP so far, which
    # should grow; the LR is reduced after 15 epochs without improvement.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=15,
                                                     verbose=True,
                                                     mode="max")
    # running window of the last 1024 iteration losses
    loss_hist = collections.deque(maxlen=1024)

    retinanet.train()
    retinanet.module.freeze_bn()
    os.makedirs("./logs", exist_ok=True)
    print('Num training images: {}'.format(len(dataset_train)))
    best_map = 0
    print("Training models...")
    # `with` guarantees the log file is flushed and closed.
    with open("./logs/log.txt", "w") as log_file:
        for epoch_num in range(parser.epochs):

            retinanet.train()
            retinanet.module.freeze_bn()

            epoch_loss = []

            for iter_num, data in enumerate(dataloader_train):
                try:
                    optimizer.zero_grad()

                    classification_loss, regression_loss = retinanet(
                        [data['img'].cuda().float(), data['annot']])

                    classification_loss = classification_loss.mean()
                    regression_loss = regression_loss.mean()

                    loss = classification_loss + regression_loss

                    if bool(loss == 0):
                        continue

                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

                    optimizer.step()

                    loss_hist.append(float(loss))

                    epoch_loss.append(float(loss))
                    if iter_num % 50 == 0:
                        progress = (
                            'Epoch: {} | Iteration: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'
                            .format(epoch_num, iter_num,
                                    float(classification_loss),
                                    float(regression_loss), np.mean(loss_hist)))
                        print(progress)
                        log_file.write(progress + ' \n')
                    del classification_loss
                    del regression_loss
                except Exception as e:
                    # best-effort: a failing batch is reported and skipped
                    print(e)
                    continue

            # Reset every epoch so a previous epoch's result is never reused,
            # and runs without a usable evaluation (coco, no validation set)
            # do not reach undefined names.
            mAP = None
            if parser.dataset == 'coco':

                print('Evaluating dataset')

                coco_eval.evaluate_coco(dataset_val, retinanet)

            elif parser.dataset == 'csv' and parser.csv_val is not None:

                print('Evaluating dataset')

                mAP = csv_eval.evaluate(dataset_val, retinanet)
            elif parser.dataset == 'voc' and parser.voc_val is not None:

                print('Evaluating dataset')

                mAP = voc_eval.evaluate(dataset_val, retinanet)

            is_best_map = False
            if mAP is not None:
                try:
                    current_map = mAP[0][0]
                except (KeyError, IndexError, TypeError) as e:
                    print('Could not read mAP[0][0]: {}'.format(e))
                else:
                    is_best_map = current_map > best_map
                    best_map = max(current_map, best_map)

            if is_best_map:
                print("Get better map: ", best_map)

                snapshot_path = './logs/{}_scale15_{}.pt'.format(
                    epoch_num, best_map)
                torch.save(retinanet.module, snapshot_path)
                os.makedirs("./best_models", exist_ok=True)
                shutil.copyfile(snapshot_path, "./best_models/model.pt")
            else:
                print("Current map: ", best_map)
            scheduler.step(best_map)
    retinanet.eval()

    torch.save(retinanet, './logs/model_final.pt')
Пример #27
0
def main(args=None):

    parser = argparse.ArgumentParser(
        description='Simple training script for training a RetinaNet network.')

    parser.add_argument('--dataset',
                        help='Dataset type, must be one of csv or coco.',
                        default="csv")
    parser.add_argument('--coco_path', help='Path to COCO directory')
    parser.add_argument(
        '--csv_train',
        help='Path to file containing training annotations (see readme)')
    parser.add_argument('--csv_classes',
                        help='Path to file containing class list (see readme)',
                        default="binary_class.csv")
    parser.add_argument(
        '--csv_val',
        help=
        'Path to file containing validation annotations (optional, see readme)'
    )

    parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=18)
    parser.add_argument('--epochs',
                        help='Number of epochs',
                        type=int,
                        default=500)
    parser.add_argument('--epochs_only_det',
                        help='Number of epochs to train detection part',
                        type=int,
                        default=1)
    parser.add_argument('--max_epochs_no_improvement',
                        help='Max epochs without improvement',
                        type=int,
                        default=100)
    parser.add_argument('--pretrained_model',
                        help='Path of .pt file with pretrained model',
                        default='esposallescsv_retinanet_0.pt')
    parser.add_argument('--model_out',
                        help='Path of .pt file with trained model to save',
                        default='trained')

    parser.add_argument('--score_threshold',
                        help='Score above which boxes are kept',
                        type=float,
                        default=0.5)
    parser.add_argument('--nms_threshold',
                        help='Score above which boxes are kept',
                        type=float,
                        default=0.2)
    parser.add_argument('--max_boxes',
                        help='Max boxes to be fed to recognition',
                        default=95)
    parser.add_argument('--seg_level',
                        help='[line, word], to choose anchor aspect ratio',
                        default='word')
    parser.add_argument(
        '--early_stop_crit',
        help='Early stop criterion, detection (map) or transcription (cer)',
        default='cer')
    parser.add_argument('--max_iters_epoch',
                        help='Max steps per epoch (for debugging)',
                        default=1000000)
    parser.add_argument('--train_htr',
                        help='Train recognition or not',
                        default='True')
    parser.add_argument('--train_det',
                        help='Train detection or not',
                        default='True')
    parser.add_argument(
        '--binary_classifier',
        help=
        'Wether to use classification branch as binary or not, multiclass instead.',
        default='False')
    parser.add_argument(
        '--htr_gt_box',
        help='Train recognition branch with box gt (for debugging)',
        default='False')
    parser.add_argument(
        '--ner_branch',
        help='Train named entity recognition with separate branch',
        default='False')

    parser = parser.parse_args(args)

    if parser.dataset == 'csv':

        if parser.csv_train is None:
            raise ValueError('Must provide --csv_train')

        dataset_name = parser.csv_train.split("/")[-2]

        dataset_train = CSVDataset(train_file=parser.csv_train,
                                   class_list=parser.csv_classes,
                                   transform=transforms.Compose(
                                       [Normalizer(),
                                        Augmenter(),
                                        Resizer()]))

        if parser.csv_val is None:
            dataset_val = None
            print('No validation annotations provided.')
        else:
            dataset_val = CSVDataset(train_file=parser.csv_val,
                                     class_list=parser.csv_classes,
                                     transform=transforms.Compose(
                                         [Normalizer(),
                                          Resizer()]))

    else:
        raise ValueError(
            'Dataset type not understood (must be csv or coco), exiting.')

    # Files for training log

    experiment_id = str(time.time()).split('.')[0]
    valid_cer_f = open('trained_models/' + parser.model_out + 'log.txt', 'w')
    for arg in vars(parser):
        if getattr(parser, arg) is not None:
            valid_cer_f.write(
                str(arg) + ' ' + str(getattr(parser, arg)) + '\n')

    current_commit = subprocess.check_output(['git', 'rev-parse', 'HEAD'])
    valid_cer_f.write(str(current_commit))

    valid_cer_f.write(
        "epoch_num   cer     best cer     mAP    best mAP     time\n")

    valid_cer_f.close()

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=1,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=3,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    if dataset_val is not None:
        sampler_val = AspectRatioBasedSampler(dataset_val,
                                              batch_size=1,
                                              drop_last=False)
        dataloader_val = DataLoader(dataset_val,
                                    num_workers=0,
                                    collate_fn=collater,
                                    batch_sampler=sampler_val)

    if not os.path.exists('trained_models'):
        os.mkdir('trained_models')

    # Create the model

    train_htr = parser.train_htr == 'True'
    htr_gt_box = parser.htr_gt_box == 'True'
    ner_branch = parser.ner_branch == 'True'
    binary_classifier = parser.binary_classifier == 'True'
    torch.backends.cudnn.benchmark = False

    alphabet = dataset_train.alphabet
    if os.path.exists(parser.pretrained_model):
        retinanet = torch.load(parser.pretrained_model)
        retinanet.classificationModel = ClassificationModel(
            num_features_in=256,
            num_anchors=retinanet.anchors.num_anchors,
            num_classes=dataset_train.num_classes())
        if ner_branch:
            retinanet.nerModel = NERModel(
                feature_size=256,
                pool_h=retinanet.pool_h,
                n_classes=dataset_train.num_classes(),
                pool_w=retinanet.pool_w)
    else:
        if parser.depth == 18:
            retinanet = model.resnet18(num_classes=dataset_train.num_classes(),
                                       pretrained=True,
                                       max_boxes=int(parser.max_boxes),
                                       score_threshold=float(
                                           parser.score_threshold),
                                       seg_level=parser.seg_level,
                                       alphabet=alphabet,
                                       train_htr=train_htr,
                                       htr_gt_box=htr_gt_box,
                                       ner_branch=ner_branch,
                                       binary_classifier=binary_classifier)

        elif parser.depth == 34:

            retinanet = model.resnet34(num_classes=dataset_train.num_classes(),
                                       pretrained=True,
                                       max_boxes=int(parser.max_boxes),
                                       score_threshold=float(
                                           parser.score_threshold),
                                       seg_level=parser.seg_level,
                                       alphabet=alphabet,
                                       train_htr=train_htr,
                                       htr_gt_box=htr_gt_box)

        elif parser.depth == 50:
            retinanet = model.resnet50(num_classes=dataset_train.num_classes(),
                                       pretrained=True)
        elif parser.depth == 101:
            retinanet = model.resnet101(
                num_classes=dataset_train.num_classes(), pretrained=True)
        elif parser.depth == 152:
            retinanet = model.resnet152(
                num_classes=dataset_train.num_classes(), pretrained=True)
        else:
            raise ValueError(
                'Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    use_gpu = True
    train_htr = parser.train_htr == 'True'
    train_det = parser.train_det == 'True'
    retinanet.htr_gt_box = parser.htr_gt_box == 'True'

    retinanet.train_htr = train_htr
    retinanet.epochs_only_det = parser.epochs_only_det

    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet = torch.nn.DataParallel(retinanet).cuda()

    retinanet.training = True

    optimizer = optim.Adam(retinanet.parameters(), lr=1e-4)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=50,
                                                     verbose=True)

    loss_hist = collections.deque(maxlen=500)
    ctc = CTCLoss()
    retinanet.train()
    retinanet.module.freeze_bn()

    best_cer = 1000
    best_map = 0
    epochs_no_improvement = 0
    verbose_each = 20
    optimize_each = 1
    objective = 100
    best_objective = 10000

    print(('Num training images: {}'.format(len(dataset_train))))

    for epoch_num in range(parser.epochs):
        cers = []

        retinanet.training = True

        retinanet.train()
        retinanet.module.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            if iter_num > int(parser.max_iters_epoch): break
            try:
                if iter_num % optimize_each == 0:
                    optimizer.zero_grad()
                (classification_loss, regression_loss, ctc_loss,
                 ner_loss) = retinanet([
                     data['img'].cuda().float(), data['annot'], ctc, epoch_num
                 ])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()
                if train_det:

                    if train_htr:
                        loss = ctc_loss + classification_loss + regression_loss + ner_loss

                    else:
                        loss = classification_loss + regression_loss + ner_loss

                elif train_htr:
                    loss = ctc_loss

                else:
                    continue
                if bool(loss == 0):
                    continue
                loss.backward()
                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)
                if iter_num % verbose_each == 0:
                    print((
                        'Epoch: {} | Step: {} |Classification loss: {:1.5f} | Regression loss: {:1.5f} | CTC loss: {:1.5f} | NER loss: {:1.5f} | Running loss: {:1.5f} | Total loss: {:1.5f}\r'
                        .format(epoch_num, iter_num,
                                float(classification_loss),
                                float(regression_loss), float(ctc_loss),
                                float(ner_loss), np.mean(loss_hist),
                                float(loss), "\r")))

                optimizer.step()

                loss_hist.append(float(loss))

                epoch_loss.append(float(loss))
                torch.cuda.empty_cache()

            except Exception as e:
                print(e)
                continue
        if parser.dataset == 'csv' and parser.csv_val is not None and train_det:

            print('Evaluating dataset')

            mAP, text_mAP, current_cer = csv_eval.evaluate(
                dataset_val, retinanet, score_threshold=parser.score_threshold)
            #text_mAP,_ = csv_eval_binary_map.evaluate(dataset_val, retinanet,score_threshold=parser.score_threshold)
            objective = current_cer * (1 - mAP)

        retinanet.eval()
        retinanet.training = False
        retinanet.score_threshold = float(parser.score_threshold)
        '''for idx,data in enumerate(dataloader_val):
            if idx>int(parser.max_iters_epoch): break
            print("Eval CER on validation set:",idx,"/",len(dataset_val),"\r")
            image_name = dataset_val.image_names[idx].split('/')[-1].split('.')[-2]

            #generate_pagexml(image_name,data,retinanet,parser.score_threshold,parser.nms_threshold,dataset_val)
            text_gt =".".join(dataset_val.image_names[idx].split('.')[:-1])+'.txt'
            f =open(text_gt,'r')
            text_gt_lines=f.readlines()[0]
            transcript_pred = get_transcript(image_name,data,retinanet,float(parser.score_threshold),float(parser.nms_threshold),dataset_val,alphabet)
            cers.append(float(editdistance.eval(transcript_pred,text_gt_lines))/len(text_gt_lines))'''

        t = str(time.time()).split('.')[0]

        valid_cer_f.close()
        #print("GT",text_gt_lines)
        #print("PREDS SAMPLE:",transcript_pred)

        if parser.early_stop_crit == 'cer':

            if float(objective) < float(
                    best_objective):  #float(current_cer)<float(best_cer):
                best_cer = current_cer
                best_objective = objective

                epochs_no_improvement = 0
                torch.save(
                    retinanet.module, 'trained_models/' + parser.model_out +
                    '{}_retinanet.pt'.format(parser.dataset))

            else:
                epochs_no_improvement += 1
            if mAP > best_map:
                best_map = mAP
        elif parser.early_stop_crit == 'map':
            if mAP > best_map:
                best_map = mAP
                epochs_no_improvement = 0
                torch.save(
                    retinanet.module, 'trained_models/' + parser.model_out +
                    '{}_retinanet.pt'.format(parser.dataset))

            else:
                epochs_no_improvement += 1
            if float(current_cer) < float(best_cer):
                best_cer = current_cer
        if train_det:
            print(epoch_num, "mAP: ", mAP, " best mAP", best_map)
        if train_htr:
            print("VALID CER:", current_cer, "best CER", best_cer)
        print("Epochs no improvement:", epochs_no_improvement)
        valid_cer_f = open('trained_models/' + parser.model_out + 'log.txt',
                           'a')
        valid_cer_f.write(
            str(epoch_num) + " " + str(current_cer) + " " + str(best_cer) +
            ' ' + str(mAP) + ' ' + str(best_map) + ' ' + str(text_mAP) + '\n')
        if epochs_no_improvement > 3:
            for param_group in optimizer.param_groups:
                if param_group['lr'] > 10e-5:
                    param_group['lr'] *= 0.1

        if epochs_no_improvement >= parser.max_epochs_no_improvement:
            print("TRAINING FINISHED AT EPOCH", epoch_num, ".")
            sys.exit()

        scheduler.step(np.mean(epoch_loss))
        torch.cuda.empty_cache()

    retinanet.eval()
Пример #28
0
# Persist ``inverse_dict`` as class_indices.json.
# Presumably it maps class index -> class name (the inverse of the generator's
# class_indices) so predictions can be decoded later -- confirm against caller.
with open('class_indices.json', 'w', encoding='utf-8') as json_file:
    json.dump(inverse_dict, json_file, indent=4)

# Batch size per dataset variant: 32 for "green", 21 for "black", 30 for "oubao".
# NOTE(review): the validation iterator is drawn from ``train_image_generator``;
# if that generator applies training-time augmentation, validation images are
# augmented as well -- confirm a separate non-augmenting generator is not intended.
val_data_gen = train_image_generator.flow_from_directory(
    directory=validation_dir,
    batch_size=30,
    shuffle=True,
    target_size=(im_height, im_width),
    class_mode='categorical')
# Number of validation images found in ``validation_dir``.
total_val = val_data_gen.n

model = resnet101(num_classes=2, include_top=True)
model.summary()

# Training uses the low-level Keras API (hand-written loop) instead of model.fit.
# from_logits=False: the model output is treated as probabilities, not logits.
loss_object = tf.keras.losses.CategoricalCrossentropy(from_logits=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0002)

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
# AUC is tracked as a running mean of values fed in elsewhere
# (presumably per-batch AUC scores -- TODO confirm).
train_auc = tf.keras.metrics.Mean(name='train_auc')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy(name='test_accuracy')
# Metric name now matches the variable and its ``train_auc`` counterpart
# (was mislabelled 'test_tp').
test_auc = tf.keras.metrics.Mean(name='test_auc')

Пример #29
0
def run_train(train_verbose=False):
    """Train ``model.resnet101(20, True)`` with mAP-driven checkpointing.

    All settings (batch size, workers, learning rate, epochs, model name)
    come from the module-level ``opt``.  Every epoch trains on
    ``Dataset(opt)``, then evaluates on ``TestDataset(opt)`` via
    ``Trainer.run_eval``.  A checkpoint is saved whenever the validation
    mAP is not lower than the best seen so far.  When more than five
    consecutive epochs fail to improve, the best checkpoint is reloaded, the
    learning rate is scaled by 0.1 and ``model_freeze(freeze_num=0)`` is
    called.

    Keyword Arguments:
        train_verbose {bool} -- print per-iteration loss/timing (default: {False})
    """
    train_loader = data_.DataLoader(
        Dataset(opt),
        batch_size=opt.batch_size,
        shuffle=True,
        num_workers=opt.num_workers,
        # pin_memory=True,
    )
    eval_loader = data_.DataLoader(
        TestDataset(opt),
        batch_size=opt.batch_size,
        num_workers=opt.num_workers,
        shuffle=False,
        # pin_memory=True,
    )

    net = torch.nn.DataParallel(model.resnet101(20, True)).cuda()
    trainer = Trainer(net,
                      optim.Adam(net.parameters(), lr=opt.lr),
                      model_name=opt.model_name)

    running_losses = collections.deque(maxlen=500)
    # Never appended to: the validation-loss pass that used to fill it is
    # disabled, so the final print shows an empty list.
    val_loss_history = []

    frozen_layers = 8
    patience = 5
    best_map, best_epoch, stale_epochs = 0, 0, 0
    trainer.model_freeze(freeze_num=frozen_layers)

    for epoch in range(opt.epoch):
        trainer.train_mode(frozen_layers)
        epoch_start = time.time()
        epoch_losses = []
        tick = time.time()

        for step, batch in enumerate(train_loader):
            loss = trainer.train_step(batch)
            running_losses.append(float(loss))
            epoch_losses.append(float(loss))

            if train_verbose:
                print('Epoch: {} | Iteration: {} | loss: {:1.5f} | Running loss: {:1.5f} | Iter time: {:1.5f} | Train time: {:1.5f}'
                      .format(epoch, step, float(loss), np.mean(running_losses),
                              time.time() - tick, time.time() - epoch_start))
                tick = time.time()

            # Drop the reference before fetching the next batch.
            del loss

        print('train epoch time :', time.time() - epoch_start)
        print('Epoch: {} | epoch train loss: {:1.5f}'.format(
            epoch, np.mean(epoch_losses)))

        eval_start = time.time()
        trainer.eval_mode()
        result = trainer.run_eval(eval_loader)
        print(result)
        print('vali epoch time :', time.time() - eval_start)

        epoch_map = result['map']
        if best_map > epoch_map:
            stale_epochs += 1
        else:
            # Ties count as improvements, so the newest equal-mAP epoch wins.
            best_map, best_epoch, stale_epochs = epoch_map, epoch, 0
            trainer.model_save(epoch)

        if stale_epochs > patience:
            # Plateau: roll back to the best checkpoint, decay LR, unfreeze.
            stale_epochs = 0
            trainer.model_load(best_epoch)
            trainer.reduce_lr(factor=0.1, verbose=True)
            trainer.model_freeze(freeze_num=0)
            # NOTE(review): ``frozen_layers`` stays 8, so later epochs still
            # call ``train_mode(8)`` after unfreezing -- confirm intended.

        print('best epoch num', best_epoch)
        print('----------------------------------------')

    print(val_loss_history)
Пример #30
0
def main(args=None):
    """Train a CTracker network on a CSV-described tracking dataset.

    ``args`` is handed to ``argparse`` (``None`` means ``sys.argv[1:]``).
    Each training sample provides ``img``/``annot`` and
    ``img_next``/``annot_next``.  The network returns classification,
    regression and re-identification losses, which are summed and optimised
    with Adam (lr 5e-5) plus ``ReduceLROnPlateau`` on the mean epoch loss.
    The whole ``DataParallel`` model is pickled every ``--save_every`` epochs
    and once more as ``model_final.pt``.  ``run_from_train`` is then called on
    the output directory.

    Keyword Arguments:
        args {list or None} -- argument strings for argparse (default: {None})

    Raises:
        ValueError -- unsupported --dataset or --depth, or an empty
            --csv_train / --csv_classes.
    """
    arg_parser = argparse.ArgumentParser(
        description='Simple training script for training a CTracker network.')

    arg_parser.add_argument('--dataset',
                            default='csv',
                            type=str,
                            help='Dataset type; only csv is supported.')
    arg_parser.add_argument('--model_dir',
                            default='./ctracker/',
                            type=str,
                            help='Path to save the model.')
    arg_parser.add_argument(
        '--root_path',
        default='/Dataset/Tracking/MOT17/',
        type=str,
        help='Path of the directory containing both label and images')
    arg_parser.add_argument(
        '--csv_train',
        default='train_annots.csv',
        type=str,
        help='Path to file containing training annotations (see readme)')
    arg_parser.add_argument(
        '--csv_classes',
        default='train_labels.csv',
        type=str,
        help='Path to file containing class list (see readme)')
    arg_parser.add_argument(
        '--depth',
        help='Resnet depth, must be one of 18, 34, 50, 101, 152',
        type=int,
        default=50)
    arg_parser.add_argument('--epochs',
                            help='Number of epochs',
                            type=int,
                            default=100)
    arg_parser.add_argument(
        '--print_freq',
        help='Print frequency in iterations (<= 0 disables printing)',
        type=int,
        default=100)
    arg_parser.add_argument(
        '--save_every',
        help='Save a checkpoint of model at given interval of epochs '
        '(<= 0 disables periodic checkpoints)',
        type=int,
        default=5)
    # Previously hard-coded; defaults keep the old behaviour.
    arg_parser.add_argument('--batch_size',
                            help='Images per batch for the sampler',
                            type=int,
                            default=8)
    arg_parser.add_argument('--num_workers',
                            help='Number of DataLoader worker processes',
                            type=int,
                            default=32)

    opts = arg_parser.parse_args(args)
    print(opts)

    print(opts.model_dir)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(opts.model_dir, exist_ok=True)

    # ---- Data -------------------------------------------------------------
    if opts.dataset != 'csv':
        raise ValueError(
            'Dataset type not understood (only csv is supported), exiting.')
    if not opts.csv_train:
        raise ValueError('Must provide --csv_train when training on CSV.')
    if not opts.csv_classes:
        raise ValueError('Must provide --csv_classes when training on CSV.')

    dataset_train = CSVDataset(
        opts.root_path,
        train_file=os.path.join(opts.root_path, opts.csv_train),
        class_list=os.path.join(opts.root_path, opts.csv_classes),
        transform=transforms.Compose([
            RandomSampleCrop(),
            PhotometricDistort(),
            Augmenter(),
            Normalizer()
        ]))

    sampler = AspectRatioBasedSampler(dataset_train,
                                      batch_size=opts.batch_size,
                                      drop_last=False)
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=opts.num_workers,
                                  collate_fn=collater,
                                  batch_sampler=sampler)

    # ---- Model ------------------------------------------------------------
    if opts.depth not in (18, 34, 50, 101, 152):
        raise ValueError(
            'Unsupported model depth, must be one of 18, 34, 50, 101, 152')
    # Looked up lazily so only the requested constructor must exist.
    build_net = getattr(model, 'resnet{}'.format(opts.depth))
    retinanet = build_net(num_classes=dataset_train.num_classes(),
                          pretrained=True)
    retinanet = torch.nn.DataParallel(retinanet.cuda())

    optimizer = optim.Adam(retinanet.parameters(), lr=5e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=3,
                                                     verbose=True)

    loss_hist = collections.deque(maxlen=500)

    print('Num training images: {}'.format(len(dataset_train)))

    # ---- Training ---------------------------------------------------------
    total_iter = 0
    for epoch_num in range(opts.epochs):
        retinanet.train()
        # Re-applied every epoch because train() above resets all sub-modules.
        retinanet.module.freeze_bn()

        epoch_loss = []

        for iter_num, data in enumerate(dataloader_train):
            total_iter += 1
            try:
                optimizer.zero_grad()

                (classification_loss, regression_loss), reid_loss = retinanet([
                    data['img'].cuda().float(), data['annot'],
                    data['img_next'].cuda().float(), data['annot_next']
                ])

                classification_loss = classification_loss.mean()
                regression_loss = regression_loss.mean()
                reid_loss = reid_loss.mean()
                loss = classification_loss + regression_loss + reid_loss

                # Nothing to learn from this batch.
                if bool(loss == 0):
                    continue

                loss.backward()
                torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)
                optimizer.step()
            except Exception as e:
                # Deliberate best-effort: skip the failing batch, but report
                # where it happened instead of printing a bare message.
                print('Epoch {} | Iter {}: skipping batch after error: {!r}'.
                      format(epoch_num, iter_num, e))
                continue

            loss_hist.append(float(loss))
            epoch_loss.append(float(loss))

            # print_freq <= 0 disables logging instead of dividing by zero.
            if opts.print_freq > 0 and total_iter % opts.print_freq == 0:
                print(
                    'Epoch: {} | Iter: {} | Cls loss: {:1.5f} | Reid loss: {:1.5f} | Reg loss: {:1.5f} | Running loss: {:1.5f}'
                    .format(epoch_num, iter_num, float(classification_loss),
                            float(reid_loss), float(regression_loss),
                            np.mean(loss_hist)))

        if epoch_loss:
            scheduler.step(np.mean(epoch_loss))
        else:
            # np.mean([]) would feed NaN to the scheduler.
            print('Epoch {}: no successful iterations; skipping LR scheduler '
                  'step'.format(epoch_num))

        # save_every <= 0 disables periodic checkpoints instead of crashing.
        if opts.save_every > 0 and epoch_num % opts.save_every == 0:
            torch.save(
                retinanet,
                os.path.join(opts.model_dir,
                             'weights_epoch_{}.pt'.format(epoch_num)))

    retinanet.eval()

    torch.save(retinanet, os.path.join(opts.model_dir, 'model_final.pt'))
    run_from_train(opts.model_dir, opts.root_path)