Example #1

import os

import torch

# Assumed project-level imports; this example does not show where these modules live
from datasets import PascalVOCDataset
from model import SSD300, MultiBoxLoss
import utils


def main():
    # Prepare train dataset and dataloader
    train_ds = PascalVOCDataset('./data', 'TRAIN', keep_difficult=keep_difficult)
    train_loader = torch.utils.data.DataLoader(train_ds,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               collate_fn=train_ds.collate_fn,  # note that we're passing the collate function here
                                               num_workers=num_workers,
                                               pin_memory=True)
    n_classes = len(train_ds.label_map())
    start_epoch = 0

    # Initialize model
    model = SSD300(n_classes=n_classes)

    # Load a checkpoint if one exists
    checkpoint = None
    if checkpoint_path is not None and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        start_epoch = checkpoint['epoch'] + 1
        print('Loaded checkpoint from epoch %d.\n' % checkpoint['epoch'])

    if checkpoint is not None:
        model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.train()

    # Initialize the optimizer, with twice the default learning rate for biases, as in the original Caffe repo
    biases = list()
    not_biases = list()
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            if param_name.endswith('.bias'):
                biases.append(param)
            else:
                not_biases.append(param)
    optimizer = torch.optim.SGD(params=[{'params': biases, 'lr': 2 * lr}, {'params': not_biases}],
                                lr=lr, momentum=momentum, weight_decay=weight_decay)

    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy).to(device)

    # Calculate total number of epochs to train and the epochs to decay learning rate at (i.e. convert iterations to epochs)
    epochs = iterations // (len(train_ds) // batch_size)
    decay_lr_at_epochs = [it // (len(train_ds) // batch_size) for it in decay_lr_at]

    # Epochs
    for epoch in range(start_epoch, epochs):

        # Decay learning rate at particular epochs
        if epoch in decay_lr_at_epochs:
            utils.adjust_learning_rate(optimizer, decay_lr_to)

        # One epoch's training
        train(train_loader, model, criterion, optimizer, epoch)

        # Save checkpoint
        utils.save_checkpoint(checkpoint_path, model, optimizer, epoch)
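
main() above depends on module-level hyperparameters and on project helpers (train(), utils.adjust_learning_rate(), utils.save_checkpoint()) that the example does not show. Below is a minimal sketch of the missing configuration and entry point; the schedule values (lr, iterations, decay points) are assumptions modeled on the original SSD Caffe settings, not taken from this example.

# Hyperparameters assumed by main(); the schedule values are illustrative
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

keep_difficult = True          # include 'difficult' ground-truth objects, matching the evaluation example below
batch_size = 8                 # training batch size (assumed; the evaluation example uses 64)
num_workers = 4                # dataloader worker processes
checkpoint_path = 'ssd300.pt'  # checkpoint file to resume from and save to

lr = 1e-3                      # base learning rate (biases get 2 * lr)
momentum = 0.9
weight_decay = 5e-4
iterations = 120000            # total training iterations, converted to epochs in main()
decay_lr_at = [80000, 100000]  # iterations at which to decay the learning rate
decay_lr_to = 0.1              # factor to scale the learning rate by at each decay

if __name__ == '__main__':
    main()

For completeness, a plausible sketch of the utils.adjust_learning_rate() helper called in the epoch loop, assuming it simply scales every parameter group:

def adjust_learning_rate(optimizer, scale):
    # Scale the learning rate of every parameter group by `scale`
    for param_group in optimizer.param_groups:
        param_group['lr'] *= scale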

Example #2

import torch
from tqdm import tqdm

# Assumed project-level imports, as in Example #1
from datasets import PascalVOCDataset
from model import SSD300

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Parameters
keep_difficult = True  # difficult ground truth objects must always be considered in mAP calculation, because these objects DO exist!
batch_size = 64
workers = 4
checkpoint_path = 'ssd300.pt'

# Load test data
test_dataset = PascalVOCDataset('./data', split='TEST', keep_difficult=keep_difficult)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                                          collate_fn=test_dataset.collate_fn, num_workers=workers, pin_memory=True)
n_classes = len(test_dataset.label_map())

# Load model checkpoint that is to be evaluated
checkpoint = torch.load(checkpoint_path, map_location=device)
model = SSD300(n_classes)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# Switch to eval mode
model.eval()


def evaluate(test_loader, model):
    """
    Evaluate.