# SSD training script (Pascal VOC variant): load VGG base weights, build the
# data pipeline / loss / optimizer, then run the epoch/batch training loop.
#
# NOTE(review): the statements below have been collapsed onto a single line
# (no newlines or semicolons between them), so as written this is not valid
# Python; it must be re-split one statement per line with the for/if bodies
# indented. The loop body also continues beyond this line.
#
# Setup steps, in order:
#   - copy pretrained weights `vgg_weights` into `model.vgg`, move the model to
#     `device`, and put it in training mode (`model.train()`);
#   - `mb = MultiBoxEncoder(opt)` -- presumably encodes ground-truth boxes
#     against default/prior boxes; TODO confirm against MultiBoxEncoder;
#   - dataset = VOC 2007 trainval + VOC 2012 trainval (`is_train=True`);
#   - DataLoader with `opt.batch_size`, `detection_collate`, 4 worker processes.
#     NOTE(review): no `shuffle=True` is passed, and torch's DataLoader defaults
#     to shuffle=False, so batches come in the same order every epoch --
#     confirm this is intended for training;
#   - `MultiBoxLoss(opt.num_classes, opt.neg_radio)` moved to `device`.
#     NOTE(review): `neg_radio` looks like a misspelling of "neg_ratio"
#     (presumably the hard-negative mining ratio) -- it must match whatever
#     attribute name `opt` actually defines; verify before renaming;
#   - SGD over all model parameters with lr / momentum / weight_decay from `opt`.
#
# Training loop:
#   - every `opt.lr_reduce_epoch` epochs (including epoch 0) call
#     `adjust_learning_rate(optimizer, opt.gamma, e // opt.lr_reduce_epoch)`;
#     presumably sets lr = opt.lr * gamma ** step -- TODO confirm in that helper;
#   - reset the per-epoch running totals for localization, classification and
#     combined loss;
#   - for each batch `(img, boxes)`: move `img` to `device` and start building
#     per-image ground-truth lists. Column 4 of each `box` is taken as the class
#     label; presumably columns 0-3 are box coordinates -- TODO confirm against
#     the dataset / detection_collate output format.
print('Loading base network...') model.vgg.load_state_dict(vgg_weights) model.to(device) model.train() mb = MultiBoxEncoder(opt) image_sets = [['2007', 'trainval'], ['2012', 'trainval']] dataset = VOCDetection(opt, image_sets=image_sets, is_train=True) dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, collate_fn=detection_collate, num_workers=4) criterion = MultiBoxLoss(opt.num_classes, opt.neg_radio).to(device) optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) print('start training........') for e in range(opt.epoch): if e % opt.lr_reduce_epoch == 0: adjust_learning_rate(optimizer, opt.gamma, e//opt.lr_reduce_epoch) total_loc_loss = 0 total_cls_loss = 0 total_loss = 0 for i , (img, boxes) in enumerate(dataloader): img = img.to(device) gt_boxes = [] gt_labels = [] for box in boxes: labels = box[:, 4]
# SSD training script (custom traffic-sign dataset variant): build the data
# pipeline / loss / optimizer and run the epoch/batch training loop.
#
# NOTE(review): the statements below have been collapsed onto a single line,
# so as written this is not valid Python; re-split one statement per line with
# the for/if bodies indented. The loop body continues beyond this line.
#
# Setup steps, in order:
#   - `mb = MultiBoxEncoder(opt)` -- presumably encodes ground-truth boxes
#     against default/prior boxes; TODO confirm;
#   - `image_sets` (VOC 2007/2012 trainval) is assigned, but it is not passed to
#     `CustomDetection` on this line. NOTE(review): likely left over from the
#     VOC version; remove it if nothing later reads it;
#   - dataset = `CustomDetection` over a hard-coded Google Drive (Colab mount)
#     path with `dbtype='train'`. NOTE(review): this path only exists inside a
#     Colab session with Drive mounted at /content/gdrive; consider moving it
#     into `opt`;
#   - DataLoader with `opt.batch_size`, `detection_collate`, 4 worker processes.
#     NOTE(review): no `shuffle=True` is passed (torch's DataLoader defaults to
#     shuffle=False), so batch order is fixed across epochs -- confirm intent;
#   - `MultiBoxLoss(opt.num_classes, opt.neg_radio)` moved to `device`
#     (NOTE(review): `neg_radio` is possibly a misspelling of "neg_ratio";
#     verify the attribute name on `opt`);
#   - SGD over all model parameters with lr / momentum / weight_decay from `opt`.
#
# Training loop:
#   - every `opt.lr_reduce_epoch` epochs (including epoch 0) call
#     `adjust_learning_rate(optimizer, opt.gamma, e // opt.lr_reduce_epoch)`;
#     presumably a step decay -- TODO confirm in that helper;
#   - reset the per-epoch running totals for localization, classification and
#     combined loss;
#   - for each batch `(img, boxes)`: move `img` to `device` and start collecting
#     per-image ground-truth boxes.
mb = MultiBoxEncoder(opt) image_sets = [['2007', 'trainval'], ['2012', 'trainval']] dataset = CustomDetection( opt, '/content/gdrive/MyDrive/SSD/VirtualTrafficSignDetectionDB', dbtype='train') dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size, collate_fn=detection_collate, num_workers=4) criterion = MultiBoxLoss(opt.num_classes, opt.neg_radio).to(device) optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay) print('start training........') for e in range(opt.epoch): if e % opt.lr_reduce_epoch == 0: adjust_learning_rate(optimizer, opt.gamma, e // opt.lr_reduce_epoch) total_loc_loss = 0 total_cls_loss = 0 total_loss = 0 for i, (img, boxes) in enumerate(dataloader): img = img.to(device) gt_boxes = []