def train(total_epochs=1, interval=100, resume=False, ckpt_path=''):
    """Train RetinaNet on the OpenImages training split and checkpoint it.

    Builds the training loader and initialises a RetinaNet from
    ``./model/net.pth``. If requested, it restores model/optimizer state from
    ``ckpt_path``. It then runs ``train_one_epoch`` for every epoch in
    ``[start_epoch, total_epochs)``.

    After each epoch a checkpoint is written with ``save_checkpoint``:
    * While validation is disabled (``val_loss is None``), every epoch is saved.
    * Once validation is re-enabled, only epochs that improve the best
      validation loss are saved.

    Args:
        total_epochs: Exclusive epoch index at which training stops. When
            resuming, training continues from the checkpoint's stored epoch.
        interval: Forwarded to ``train_one_epoch`` (and to ``validate`` if
            re-enabled). Presumably a logging interval in iterations; confirm
            against those helpers.
        resume: If true, try to restore state from ``ckpt_path``. A missing
            file is reported and training starts from scratch.
        ckpt_path: Path to a checkpoint dict written by this function.
    """
    print("Loading training dataset...")
    train_dset = OpenImagesDataset(root='./data/train',
                                   list_file='./data/tmp/train_images_bbox.csv',
                                   transform=transform, train=True, input_size=600)
    train_loader = data.DataLoader(train_dset, batch_size=4, shuffle=True,
                                   num_workers=4, collate_fn=train_dset.collate_fn)
    print("Loading completed.")
    # Validation is currently disabled. To enable it, build the loader here:
    # val_dset = OpenImagesDataset(root='./data/train',
    #     list_file='./data/tmp/train_images_bbox.csv', train=False,
    #     transform=transform, input_size=600)
    # val_loader = torch.utils.data.DataLoader(val_dset, batch_size=1, shuffle=False,
    #     num_workers=4, collate_fn=val_dset.collate_fn)

    net = RetinaNet()
    net.load_state_dict(torch.load('./model/net.pth'))
    criterion = FocalLoss()
    net.cuda()
    criterion.cuda()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)

    # Use +inf so that the first real validation loss always counts as an
    # improvement, whatever its magnitude.
    best_val_loss = float('inf')
    start_epoch = 0
    if resume:
        if os.path.isfile(ckpt_path):
            print(f'Loading from the checkpoint {ckpt_path}')
            checkpoint = torch.load(ckpt_path)
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f'Loaded checkpoint {ckpt_path}, epoch : {start_epoch}')
        else:
            print(f'No check point found at the path {ckpt_path}')

    for epoch in range(start_epoch, total_epochs):
        train_one_epoch(train_loader, net, criterion, optimizer, epoch, interval)

        # None means "no validation ran this epoch". The previous placeholder (0)
        # made only the very first epoch look like an improvement, so later
        # epochs were never checkpointed.
        val_loss = None
        # val_loss = validate(val_loader, net, criterion, interval)
        if val_loss is not None:
            if val_loss >= best_val_loss:
                continue  # no improvement: keep the previous best checkpoint
            best_val_loss = val_loss

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': net.state_dict(),
            'best_val_loss': best_val_loss,
            'optimizer': optimizer.state_dict(),
        }, is_best=True)
from fpn import FPN50
from retinanet import RetinaNet

# Build initial RetinaNet weights:
# * ImageNet-pretrained ResNet50 tensors copied into an FPN50 backbone,
# * N(0, 0.01) conv init elsewhere,
# * a prior on the final classification bias.
# Assumes ``torch``, ``nn``, ``init`` (torch.nn.init) and ``math`` are imported
# earlier in the file.
print('Loading pretrained ResNet50 model..')
d = torch.load('./model/resnet50.pth')

print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()
# Copy every pretrained tensor except the ImageNet classifier ("fc.*").
dd.update({k: v for k, v in d.items() if not k.startswith('fc')})

print('Saving RetinaNet..')
net = RetinaNet()
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # In-place initialisers; the non-underscore aliases are deprecated.
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

# sigmoid(-log((1 - pi) / pi)) == pi: if class scores go through a sigmoid
# (focal-loss setup), every anchor starts with foreground probability ~pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

# NOTE(review): assuming net.fpn is an FPN50 (not visible here), this replaces
# all of its parameters with ``dd``. The init loop above therefore only survives
# in the heads, and FPN-only layers keep the separate FPN50()'s default init.
net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), 'net.pth')
print('Done!')
from retinanet import RetinaNet

# Build initial RetinaNet weights from an ImageNet-pretrained ResNet50.
# Assumes ``FPN50``, ``torch``, ``nn``, ``init`` (torch.nn.init) and ``math``
# are already in scope from earlier in the file.
print('Loading pretrained ResNet50 model..')
d = torch.load('./model/resnet50.pth')

print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()
# Copy every pretrained tensor except the ImageNet classifier ("fc.*").
dd.update({k: v for k, v in d.items() if not k.startswith('fc')})

print('Saving RetinaNet..')
net = RetinaNet()
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # In-place initialisers; init.normal / init.constant are deprecated aliases.
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

# sigmoid(-log((1 - pi) / pi)) == pi: sets the initial foreground prior to ~pi
# if class scores go through a sigmoid.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1-pi)/pi))

# NOTE(review): this overwrites every net.fpn parameter present in ``dd``, so the
# init loop above only takes effect in the heads.
net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), 'net.pth')
print('Done!')
# Second half of the SE-ResNet50 -> single-class RetinaNet init script.
# ``d`` (the pretrained state dict) is loaded earlier in this script.
print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()
# Copy pretrained tensors in order until the first classifier ("last_linear")
# entry is reached. Everything after it is skipped, so this assumes the
# classifier keys come last in ``d`` (TODO confirm).
for k, weight in d.items():
    if 'last_linear' in k:
        print("break : ", k)
        break
    dd[k] = weight

print('Saving RetinaNet..')
net = RetinaNet(1)
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
        continue
    if isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

# Classification bias such that sigmoid(bias) == pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), 'weights/retinanet_se50.pth')
print('Done!')
# Build initial RetinaNet weights from an ImageNet-pretrained ResNet50.
# Both paths below are relative to the current working directory, not to
# this file's location.
print('Loading pretrained ResNet50 model..')
d = torch.load('../model/resnet50.pth')

print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()
# Take every pretrained tensor except the ImageNet classifier ("fc.*").
dd.update({name: tensor for name, tensor in d.items() if not name.startswith('fc')})

print('Saving RetinaNet..')
net = RetinaNet()
for module in net.modules():
    if isinstance(module, nn.Conv2d):
        init.normal_(module.weight, mean=0, std=0.01)
        if module.bias is not None:
            init.constant_(module.bias, 0)
    elif isinstance(module, nn.BatchNorm2d):
        module.weight.data.fill_(1)
        module.bias.data.zero_()

# Classification bias such that sigmoid(bias) == pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), '../model/net.pth')
print('Done!')
# This chunk begins with the tail of a pretrained-weight copy loop whose header
# is above this chunk. It then:
# * builds a RetinaNet;
# * inits net.fpn convs with Xavier-uniform (zero bias) and BN weight/bias to 1/0;
# * inits cls_head/loc_head convs with N(0, 0.01) and zero bias;
# * sets the final classification bias to -log((1 - pi) / pi), i.e. sigmoid(bias) == pi;
# * loads ``dd`` into net.fpn and saves to '../pretrained_model/net.pth'.
# NOTE(review): net.fpn.load_state_dict(dd) runs *after* the net.fpn init loop.
# If ``dd`` is a complete FPN state_dict (presumably fpn.state_dict() from above
# this chunk; confirm), it overwrites every net.fpn parameter and buffer, so
# the Xavier/BN init here has no effect.
# NOTE(review): init.xavier_uniform / init.normal / init.constant are deprecated
# aliases; prefer init.xavier_uniform_ / init.normal_ / init.constant_.
dd[k] = d[k] print('Saving RetinaNet..') net = RetinaNet() for m in net.fpn.modules(): if isinstance(m, nn.Conv2d): init.xavier_uniform(m.weight) if m.bias is not None: init.constant(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() for m in net.cls_head.modules(): if isinstance(m, nn.Conv2d): init.normal(m.weight, mean=0, std=0.01) if m.bias is not None: init.constant(m.bias, 0) for m in net.loc_head.modules(): if isinstance(m, nn.Conv2d): init.normal(m.weight, mean=0, std=0.01) if m.bias is not None: init.constant(m.bias, 0) pi = 0.01 init.constant(net.cls_head[-1].bias, -math.log((1 - pi) / pi)) net.fpn.load_state_dict(dd) torch.save(net.state_dict(), '../pretrained_model/net.pth') print('Done!')
# Build initial 15-class RetinaNet weights on an FPN101 backbone.
# ``d`` (the pretrained backbone state dict) is loaded earlier in this script.
fpn = FPN101()
dd = fpn.state_dict()
# Copy every pretrained tensor except the ImageNet classifier ("fc.*").
dd.update({k: v for k, v in d.items() if not k.startswith('fc')})

print('Saving RetinaNet..')
net = RetinaNet(num_classes=15)
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # In-place initialisers; init.normal / init.constant are deprecated aliases.
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        init.constant_(m.weight, 1)
        init.constant_(m.bias, 0)

# sigmoid(-log((1 - pi) / pi)) == pi: initial foreground prior of ~pi if class
# scores go through a sigmoid.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

# NOTE(review): this overwrites every net.fpn parameter present in ``dd``, so the
# init loop above only takes effect in the heads.
net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), './model/dota_15c_9ma_101.pth')
print('Done!')
# This chunk does the following:
# * builds the validation DataLoader over ``valset`` (cfg.batch_size, no
#   shuffle, 10 workers, the dataset's own collate_fn);
# * builds a RetinaNet with cfg.backbone and 99 classes, moves it to the GPU
#   and enables cudnn.benchmark.
# With args.resume, it loads a tensor dict from a hard-coded absolute path and
# merges it into the model's current state_dict. Keys present in the file
# override the fresh weights and all others keep their initial values (a
# partial warm start). The file is presumably a raw state_dict, since the
# checkpoint['epoch'] read is commented out. Confirm.
# NOTE(review): the path ignores args.exp, and start_epoch/lr are not restored
# here. Any key in the file that the model lacks makes the (strict)
# load_state_dict call raise.
# NOTE(review): FocalLoss(99) must stay in sync with num_classes=99 above.
# NOTE(review): this chunk ends mid-way through the optim.SGD(...) call.
valloader = torch.utils.data.DataLoader(valset, batch_size=cfg.batch_size, shuffle=False, num_workers=10, collate_fn=valset.collate_fn) print('Building model...') net = RetinaNet(backbone=cfg.backbone, num_classes=99) net.cuda() cudnn.benchmark = True if args.resume: print('Resuming from checkpoint..') # checkpoint = torch.load(os.path.join('ckpts', args.exp, '29_ckpt.pth'), map_location='cuda') # checkpoint = torch.load(os.path.join('ckpts', 'efnet4', '29_ckpt.pth'), map_location='cuda') model_state = net.state_dict() checkpoint = torch.load( '/media/grisha/hdd1/icevision/efnet_detector_wo_cls.pth', map_location='cuda') model_state.update(checkpoint) net.load_state_dict(model_state) # net.load_state_dict(checkpoint) # start_epoch = checkpoint['epoch'] # lr = cfg.lr criterion = FocalLoss(99) # optimizer = optim.Adam(net.parameters(), lr=cfg.lr,weight_decay=cfg.weight_decay) optimizer = optim.SGD(net.parameters(), lr=cfg.lr, momentum=cfg.momentum,