Example #1
0
def train(total_epochs=1, interval=100, resume=False, ckpt_path=''):
    """Train RetinaNet on the Open Images training split.

    Builds the training loader, loads initial weights from ``./model/net.pth``,
    optionally resumes from a checkpoint, then runs ``train_one_epoch`` for
    epochs ``start_epoch .. total_epochs - 1`` and checkpoints via
    ``save_checkpoint``.

    Args:
        total_epochs: Exclusive upper bound on the epoch index. When resuming,
            this is not the number of additional epochs.
        interval: Forwarded to ``train_one_epoch`` (presumably a logging
            interval in iterations -- TODO confirm).
        resume: If True, restore epoch, best validation loss, network and
            optimizer state from ``ckpt_path`` when that file exists.
        ckpt_path: Path to a checkpoint dict with keys ``'epoch'``,
            ``'best_val_loss'``, ``'state_dict'`` and ``'optimizer'``.
    """
    print("Loading training dataset...")
    train_dset = OpenImagesDataset(root='./data/train',
                                   list_file='./data/tmp/train_images_bbox.csv',
                                   transform=transform, train=True, input_size=600)

    train_loader = data.DataLoader(train_dset, batch_size=4, shuffle=True, num_workers=4,
                                   collate_fn=train_dset.collate_fn)

    print("Loading completed.")

    # Validation is currently disabled. To re-enable it, restore the loader
    # below and the ``validate`` call inside the epoch loop.
    # val_dset = OpenImagesDataset(root='./data/train',
    #                   list_file='./data/tmp/train_images_bbox.csv', train=False, transform=transform, input_size=600)
    # val_loader = torch.utils.data.DataLoader(val_dset, batch_size=1, shuffle=False, num_workers=4, collate_fn=val_dset.collate_fn)

    net = RetinaNet()
    net.load_state_dict(torch.load('./model/net.pth'))

    criterion = FocalLoss()

    net.cuda()
    criterion.cuda()
    optimizer = optim.SGD(net.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
    # Infinity (instead of an arbitrary large constant) guarantees that the
    # first real validation loss is always recorded as the best one.
    best_val_loss = float('inf')

    start_epoch = 0

    if resume:
        if os.path.isfile(ckpt_path):
            print(f'Loading from the checkpoint {ckpt_path}')
            checkpoint = torch.load(ckpt_path)
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            net.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f'Loaded checkpoint {ckpt_path}, epoch : {start_epoch}')
        else:
            print(f'No check point found at the path {ckpt_path}')

    for epoch in range(start_epoch, total_epochs):
        train_one_epoch(train_loader, net, criterion, optimizer, epoch, interval)

        # ``None`` means "no validation ran this epoch". The previous constant
        # placeholder (0) compared equal to itself after the first epoch, so
        # only epoch 0 was ever checkpointed.
        val_loss = None
        # val_loss = validate(val_loader, net, criterion, interval)

        # Without validation, always keep the most recent weights; otherwise
        # save only when the validation loss improves.
        improved = val_loss is None or val_loss < best_val_loss
        if improved:
            if val_loss is not None:
                best_val_loss = val_loss
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': net.state_dict(),
                'best_val_loss': best_val_loss,
                'optimizer': optimizer.state_dict()
            }, is_best=True)
Example #2
0
from fpn import FPN50
from retinanet import RetinaNet

# Build an initial RetinaNet checkpoint: pretrained ResNet50 weights for the
# FPN, freshly initialized prediction heads.
print('Loading pretrained ResNet50 model..')
d = torch.load('./model/resnet50.pth')

print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()
# Overlay the pretrained tensors onto the FPN state dict, skipping the
# classifier ('fc.*') layers.
dd.update({k: v for k, v in d.items() if not k.startswith('fc')})

print('Saving RetinaNet..')
net = RetinaNet()
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # In-place initializers; the non-underscore names are deprecated aliases.
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

# Prior-probability bias for the last classification layer:
# b = -log((1 - pi) / pi) makes sigmoid(b) == pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

# Replace the FPN's random initialization with the copied pretrained tensors.
net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), 'net.pth')
print('Done!')
from retinanet import RetinaNet


# Build an initial RetinaNet checkpoint: pretrained ResNet50 weights for the
# FPN, freshly initialized prediction heads.
print('Loading pretrained ResNet50 model..')
d = torch.load('./model/resnet50.pth')

print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()
# Overlay the pretrained tensors onto the FPN state dict, skipping the
# classifier ('fc.*') layers.
dd.update({k: v for k, v in d.items() if not k.startswith('fc')})

print('Saving RetinaNet..')
net = RetinaNet()
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # In-place initializers; the non-underscore names are deprecated aliases.
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

# Prior-probability bias for the last classification layer:
# b = -log((1 - pi) / pi) makes sigmoid(b) == pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

# Replace the FPN's random initialization with the copied pretrained tensors.
net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), 'net.pth')
print('Done!')
# Seed an FPN50 state dict from the pretrained weights in ``d`` (defined
# earlier in the script), then build a single-class RetinaNet around it.
print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()

# Copy tensors in iteration order and stop at the first 'last_linear'
# (classifier) key; that key and everything after it are not copied.
for k in d:
    if 'last_linear' in k:
        print("break : ", k)
        break
    dd[k] = d[k]

print('Saving RetinaNet..')
net = RetinaNet(1)

for m in net.modules():
    if isinstance(m, nn.BatchNorm2d):
        # Identity affine transform for batch-norm layers.
        m.weight.data.fill_(1)
        m.bias.data.zero_()
        continue
    if not isinstance(m, nn.Conv2d):
        continue
    # Conv layers: N(0, 0.01) weights, zero bias.
    init.normal_(m.weight, mean=0, std=0.01)
    if m.bias is not None:
        init.constant_(m.bias, 0)

# Last classification layer bias b = -log((1 - pi) / pi), so sigmoid(b) == pi.
pi = 0.01
init.constant_(
    net.cls_head[-1].bias,
    -math.log((1 - pi) / pi),
)

net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), 'weights/retinanet_se50.pth')
print('Done!')
Example #5
0
# NOTE: the checkpoint paths below are relative, so they resolve against the
# current working directory, not against this file's location.
print('Loading pretrained ResNet50 model..')
d = torch.load('../model/resnet50.pth')

print('Loading into FPN50..')
fpn = FPN50()
dd = fpn.state_dict()
for k in d:
    if k.startswith('fc'):
        continue  # skip classifier ('fc.*') layers
    dd[k] = d[k]

print('Saving RetinaNet..')
net = RetinaNet()
for m in net.modules():
    if isinstance(m, nn.BatchNorm2d):
        # Batch norm starts as the identity transform.
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif isinstance(m, nn.Conv2d):
        # Conv layers: N(0, 0.01) weights, zero bias.
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)

# Final classification bias b = -log((1 - pi) / pi), i.e. sigmoid(b) == pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), '../model/net.pth')
print('Done!')
        dd[k] = d[k]

print('Saving RetinaNet..')
net = RetinaNet()

# FPN: Xavier-uniform conv weights, zero conv bias, identity batch norm.
# NOTE(review): ``net.fpn.load_state_dict(dd)`` below is strict by default,
# so every FPN parameter/buffer is overwritten by ``dd``; this init then only
# matters for RNG consumption -- confirm it is still wanted.
for m in net.fpn.modules():
    if isinstance(m, nn.Conv2d):
        init.xavier_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

# Prediction heads (classification first, then box regression, matching the
# original RNG order): N(0, 0.01) conv weights, zero bias.
for head in (net.cls_head, net.loc_head):
    for m in head.modules():
        if isinstance(m, nn.Conv2d):
            init.normal_(m.weight, mean=0, std=0.01)
            if m.bias is not None:
                init.constant_(m.bias, 0)

# Final classification bias b = -log((1 - pi) / pi), so sigmoid(b) == pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), '../pretrained_model/net.pth')
print('Done!')
Example #7
0
fpn = FPN101()
dd = fpn.state_dict()
# Overlay tensors from ``d`` (a state dict loaded earlier in the script),
# skipping the classifier ('fc.*') layers.
dd.update({k: v for k, v in d.items() if not k.startswith('fc')})

print('Saving RetinaNet..')
net = RetinaNet(num_classes=15)
for m in net.modules():
    if isinstance(m, nn.Conv2d):
        # In-place initializers; the non-underscore names are deprecated aliases.
        init.normal_(m.weight, mean=0, std=0.01)
        if m.bias is not None:
            init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

# Final classification bias b = -log((1 - pi) / pi), so sigmoid(b) == pi.
pi = 0.01
init.constant_(net.cls_head[-1].bias, -math.log((1 - pi) / pi))

net.fpn.load_state_dict(dd)
torch.save(net.state_dict(), './model/dota_15c_9ma_101.pth')
print('Done!')

# Debug helper (disabled): print requires_grad for each FPN parameter.
# count = 0
# for param in fpn.parameters():
#     count += 1
#     print(param.requires_grad)
#     print(count)
Example #8
0
# Validation loader: fixed order (no shuffling) with the dataset's own collate.
valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=cfg.batch_size,
    num_workers=10,
    shuffle=False,
    collate_fn=valset.collate_fn,
)

print('Building model...')
net = RetinaNet(backbone=cfg.backbone, num_classes=99)
net.cuda()
cudnn.benchmark = True

if args.resume:
    print('Resuming from checkpoint..')
    # NOTE(review): the checkpoint path is hard-coded to one machine and
    # ignores any experiment name in ``args`` -- consider making it a CLI option.
    model_state = net.state_dict()
    checkpoint = torch.load(
        '/media/grisha/hdd1/icevision/efnet_detector_wo_cls.pth',
        map_location='cuda',
    )
    # Overlay the checkpoint onto the current state so that keys missing from
    # the checkpoint keep their fresh values (assumes ``checkpoint`` is a flat
    # state dict -- TODO confirm).
    model_state.update(checkpoint)
    net.load_state_dict(model_state)

criterion = FocalLoss(99)

# optimizer = optim.Adam(net.parameters(), lr=cfg.lr,weight_decay=cfg.weight_decay)
optimizer = optim.SGD(net.parameters(),
                      lr=cfg.lr,
                      momentum=cfg.momentum,