Example #1
import os

import torch
from torch import optim
from torch.utils.data import DataLoader

# VOCDetection, Yolo, and Loss are project-specific modules; configuration
# values such as data_root, batch_size, lr, grid_size, num_boxes, num_classes,
# device, pretrained_backbone_path, and ckpt_dir are assumed to be defined
# earlier in the script.
test_dset = VOCDetection(root=data_root, split='test')
test_dloader = DataLoader(test_dset,
                          batch_size=batch_size,
                          shuffle=False,
                          drop_last=False,
                          num_workers=8)

model = Yolo(grid_size, num_boxes, num_classes)
model = model.to(device)
# Initialize the model from the pretrained backbone checkpoint.
pretrained_weights = torch.load(pretrained_backbone_path)
model.load_state_dict(pretrained_weights)
print('Loaded pretrained weights.')

# Freeze the backbone network.
model.features.requires_grad_(False)
model_params = [v for v in model.parameters() if v.requires_grad]
optimizer = optim.SGD(model_params, lr=lr, momentum=0.9, weight_decay=5e-4)
compute_loss = Loss(grid_size, num_boxes, num_classes)

# Load the last checkpoint if it exists.
ckpt_path = os.path.join(ckpt_dir, 'last.pth')

if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    last_epoch = ckpt['epoch'] + 1
    print('Last checkpoint is loaded. start_epoch:', last_epoch)
else:
    print('No checkpoint is found.')
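
Example #1 resumes training from a checkpoint dictionary with 'model', 'optimizer', and 'epoch' keys. For reference, a minimal sketch of the matching save step is shown below; the tiny nn.Linear model, the ckpt_dir path, and the epoch value are placeholders used only to keep the sketch self-contained, not part of the original snippet.

import os

import torch
import torch.nn as nn
from torch import optim

# Placeholder model and optimizer so the sketch runs on its own.
model = nn.Linear(10, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01)
ckpt_dir = './checkpoints'
epoch = 0

os.makedirs(ckpt_dir, exist_ok=True)
ckpt = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': epoch,
}
# Writing to 'last.pth' matches the path that the resume code above looks for.
torch.save(ckpt, os.path.join(ckpt_dir, 'last.pth'))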
Example #2
    # Get dataloader
    train_dataset = VOCDetection(args.train_path, args.img_size)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=train_dataset.collate_fn)
    val_dataset = VOCDetection(args.val_path, args.img_size)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            collate_fn=val_dataset.collate_fn)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    scheduler = optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda epoch: 0.98**epoch)

    for epoch in range(args.epochs):

        model.train()
        start_time = time.time()

        for ind, (imgs, targets) in enumerate(train_loader):
            print(ind, imgs.shape)
            batches_done = len(train_loader) * epoch + ind
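
The excerpt in Example #2 stops before the actual optimization step. A generic sketch of the forward/backward/step pattern that would typically follow is shown below, with a placeholder linear model, MSE loss, and random data standing in for the YOLO model, detection loss, and train_loader batches from the snippet.

import torch
import torch.nn as nn
from torch import optim

# Placeholders so the sketch is self-contained; the real code would use the
# YOLO model, its detection loss, and the batches from train_loader.
model = nn.Linear(8, 4)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.98 ** epoch)
criterion = nn.MSELoss()

train_loader = [(torch.randn(16, 8), torch.randn(16, 4)) for _ in range(5)]

for epoch in range(2):
    model.train()
    for imgs, targets in train_loader:
        preds = model(imgs)
        loss = criterion(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Decay the learning rate once per epoch, matching the LambdaLR schedule above.
    scheduler.step()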