Example #1
0
        # Validation pass: run the model over test_dloader and accumulate the
        # weighted loss, then average it over len(test_dloader).
        # NOTE(review): the enclosing code (above this view) presumably wraps
        # this loop in torch.no_grad() / model.eval() and initialises
        # valid_loss — confirm.
        for x, y in test_dloader:
            # Move the input batch and its targets to the evaluation device.
            x = x.to(device)
            y = y.to(device)

            logit = model(x)
            # compute_loss (defined elsewhere) returns five separate terms.
            loss_xy, loss_wh, loss_obj, loss_noobj, loss_class = compute_loss(
                logit, y)
            # Box-coordinate terms are scaled by lambda_coord, the no-object
            # term by lambda_noobj; the sum is divided by batch_size.
            # NOTE(review): if the final batch is smaller than batch_size this
            # normalisation is slightly off — confirm intended.
            loss = (lambda_coord * (loss_xy + loss_wh) + loss_obj +
                    lambda_noobj * loss_noobj + loss_class) / batch_size
            # NOTE(review): this accumulates a tensor rather than a Python
            # float; if autograd is enabled here, each batch's graph stays
            # alive. Consider `loss.item()` once the context is confirmed.
            valid_loss += loss

        # Average over len(test_dloader) (presumably the number of batches).
        valid_loss /= len(test_dloader)

    # Bundle model weights, optimizer state and the epoch counter.
    # NOTE(review): ckpt_path is not changed in this view, so each save
    # presumably overwrites the previous checkpoint — confirm intended.
    ckpt = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch
    }
    torch.save(ckpt, ckpt_path)

# The 20 Pascal VOC category names, as a tuple.
# NOTE(review): the position in this tuple is presumably used as the class
# index by the model — keep the order unchanged.
VOC_CLASSES = (
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
)

# Directory of test images (a relative path, resolved against the CWD).
test_image_dir = 'test_images'
# Full paths of every entry in test_image_dir. os.listdir returns entries in
# arbitrary, platform-dependent order, so sort them for reproducible runs.
# NOTE(review): non-image files and sub-directories are not filtered out.
image_path_list = [
    os.path.join(test_image_dir, path)
    for path in sorted(os.listdir(test_image_dir))
]
Example #2
0
                # Convert the box columns of targets (index 2 onward) with
                # utils.xywh2xyxy, then scale them by args.img_size.
                # NOTE(review): presumably converts centre/size boxes to corner
                # format and normalised coords to pixels — confirm in utils.
                targets[:, 2:] = utils.xywh2xyxy(targets[:, 2:])
                targets[:, 2:] *= args.img_size

                # Forward pass and NMS run without gradient tracking.
                with torch.no_grad():
                    # model returns a pair; only the first element is used.
                    outputs, _ = model(imgs)
                    outputs = utils.non_max_suppression(
                        outputs,
                        conf_thresh=args.conf_thresh,
                        nms_thresh=args.nms_thresh)

                # Accumulate per-batch matching statistics for mAP, matched at
                # IoU threshold args.map_thresh (sample_metrics presumably a
                # list initialised above this view).
                sample_metrics += utils.get_batch_statistics(
                    outputs, targets, iou_thresh=args.map_thresh)

            # No statistics were collected: report instead of computing AP.
            if len(sample_metrics) == 0:
                print('---- mAP is NULL')
            else:
                # Transpose the per-sample tuples into three per-field groups
                # and concatenate each group along axis 0.
                true_positives, pred_scores, pred_labels = [
                    np.concatenate(x, 0) for x in list(zip(*sample_metrics))
                ]
                # NOTE(review): `labels` is not defined in this view —
                # presumably the ground-truth class labels collected earlier
                # in the loop; confirm. Only AP is used below.
                precision, recall, AP, f1, ap_class = utils.ap_per_class(
                    true_positives, pred_scores, pred_labels, labels)
                # mAP = mean of the per-class AP values.
                print('---- mAP %.3f' % (AP.mean()))

        # Save a weights-only checkpoint every args.checkpoint_interval epochs,
        # but only once epoch exceeds the hard-coded threshold of 20.
        if epoch % args.checkpoint_interval == 0 and epoch > 20:
            torch.save(
                model.state_dict(),
                os.path.join(args.output_path,
                             'yolov3_tiny_ckpt_%d.pth' % epoch))

        # Advance the LR scheduler (presumably once per epoch, given this
        # indentation level).
        scheduler.step()