def get_dataloader(net, train_dataset, val_dataset, train_transform, val_transform, batch_size,
                   num_workers):
    """Get dataloader."""
    train_bfn = batchify.Tuple(*[batchify.Stack() for _ in range(6)])
    train_loader = mx.gluon.data.DataLoader(
        train_dataset.transform(
            train_transform(net.short, net.base_stride, net.valid_range)),
        batch_size, shuffle=True, batchify_fn=train_bfn, last_batch='rollover',
        num_workers=num_workers)
    val_bfn = batchify.Tuple(*[batchify.Stack() for _ in range(2)])
    val_loader = mx.gluon.data.DataLoader(
        val_dataset.transform(val_transform(net.short, net.base_stride)),
        batch_size, False, batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader
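
A small self-contained sketch (not part of the snippets here) of what the batchify helpers used above do; the dummy shapes are made up for illustration and assume only that mxnet and gluoncv are installed:

import mxnet as mx
from gluoncv.data import batchify

imgs = [mx.nd.zeros((3, 4, 4)) for _ in range(2)]           # two dummy CHW "images"
labels = [mx.nd.array([[0, 1, 2, 3, 4]]),                   # one box
          mx.nd.array([[5, 6, 7, 8, 9], [1, 2, 3, 4, 0]])]  # two boxes
batchify_fn = batchify.Tuple(batchify.Stack(), batchify.Pad(pad_val=-1))
img_batch, label_batch = batchify_fn(list(zip(imgs, labels)))
print(img_batch.shape)    # (2, 3, 4, 4)
print(label_batch.shape)  # (2, 2, 5) - the shorter label array is padded with -1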
Example #2
def predict(self, predictor, img):
    """Wrap a single image into a one-element batch and run batch prediction on it."""
    img_batch = batchify.Stack()([img])
    return self.batch_predict(predictor, img_batch)
Example #3
def predict_food(im_fname,
                 output_filename,
                 threshold=0.5,
                 print_outputs=False):
    """Detect food items in an image and save an annotated plot to
    ./predictions/<output_filename>.

    Bowl detections are cropped and re-classified with a fine-tuned food
    classifier; a module-level `ctx` (list of MXNet contexts) is assumed.
    """
    # COCO-pretrained YOLOv3 detector.
    net = model_zoo.get_model('yolo3_darknet53_coco', pretrained=True, ctx=ctx)

    base_classes = [
        'bowl', 'cup', 'banana', 'apple', 'sandwich', 'orange', 'broccoli',
        'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'bottle'
    ]

    net.reset_class(classes=base_classes, reuse_weights=base_classes)

    params_path = './trained_parameters/'
    symbol_file = os.path.join(
        params_path, 'ResNet50_v2_epochs50-lr0.001-wd0.001-symbol.json')
    params_file = os.path.join(
        params_path, 'ResNet50_v2_epochs50-lr0.001-wd0.001-0000.params')

    food_classes = ['borscht', 'lagman', 'manty', 'plov', 'samsy']

    all_classes = base_classes + food_classes

    food_net = nn.SymbolBlock.imports(symbol_file, ['data'],
                                      params_file,
                                      ctx=ctx)

    transform_fn = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    x, orig_img = gcv.data.transforms.presets.yolo.load_test(im_fname)

    box_ids, scores, bboxes = net.forward(x)

    box_ids_np = box_ids[0].asnumpy()
    scores_np = scores[0].asnumpy()
    bboxes_np = bboxes[0].asnumpy()

    # 'bowl' is class index 0 in base_classes after reset_class, so keep only
    # confident bowl detections; their crops are re-classified as food below.
    bowl_mask = (box_ids_np == 0) & (scores_np > threshold)

    bowl_ids = np.where(bowl_mask)

    bowl_boxes = bboxes_np[bowl_mask.ravel(), :]

    if len(bowl_boxes) > 0:
        bowl_images = [
            orig_img[int(box[1]):int(box[3]),
                     int(box[0]):int(box[2])] for box in bowl_boxes.tolist()
        ]

        # Preprocess each bowl crop and classify all of them in one batch with
        # the fine-tuned food network on the first context.
        bowl_batch_img = batchify.Stack()([
            transform_fn(mx.nd.array(img)) for img in bowl_images
        ]).copyto(ctx[0])

        food_outputs = mx.nd.softmax(food_net(bowl_batch_img))

        food_scores = food_outputs.max(axis=1)

        food_labels = food_outputs.argmax(axis=1) + len(base_classes)

        all_classes = base_classes + food_classes

        # Drop the original bowl detections and splice in the food-classifier
        # labels and scores for the same bounding boxes.
        box_ids = np.delete(box_ids_np, bowl_ids[0], axis=0)
        scores = np.delete(scores_np, bowl_ids[0], axis=0)
        bboxes = np.delete(bboxes_np, bowl_ids[0], axis=0)

        box_ids = np.concatenate(
            (box_ids, food_labels.asnumpy().reshape(-1, 1)), axis=0)
        scores = np.concatenate((scores, food_scores.asnumpy().reshape(-1, 1)),
                                axis=0)
        bboxes = np.concatenate((bboxes, bowl_boxes.reshape(-1, 4)), axis=0)

    else:
        box_ids = box_ids_np
        scores = scores_np
        bboxes = bboxes_np

    if print_outputs:
        confident_mask = (scores >= threshold)
        confident_classes = box_ids[confident_mask]
        confident_scores = scores[confident_mask]
        confident_boxes = bboxes[confident_mask.ravel()]

        for i in range(len(confident_boxes)):
            print(
                f"{all_classes[int(confident_classes[i])]:10} \t {confident_scores[i]:.5f}\t{confident_boxes[i]}"
            )

    utils.viz.plot_bbox(orig_img,
                        bboxes,
                        scores,
                        box_ids,
                        class_names=all_classes,
                        thresh=threshold)

    # plt.rc('figure', figsize=(20,20))
    fig = plt.gcf()
    fig.set_size_inches(15, 10)
    plt.axis('off')
    plt.savefig(f"./predictions/{output_filename}",
                dpi=300,
                bbox_inches='tight',
                pad_inches=0)
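
A hedged invocation sketch for predict_food: a module-level ctx (a list of MXNet contexts) must already exist since the function body uses it, the trained classifier files must be present under ./trained_parameters/, and the file names below are placeholders rather than values from the original code.

ctx = [mx.cpu()]  # hypothetical context list; a GPU context could be used instead
predict_food('sample_meal.jpg', 'sample_meal_out.png', threshold=0.5, print_outputs=True)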
Example #4
def train(net, async_net, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    if args.no_wd:
        for k, v in net.collect_params(".*beta|.*gamma|.*bias").items():
            v.wd_mult = 0.0

    if args.label_smooth:
        net._target_generator._label_smooth = True

    if args.lr_decay_period > 0:
        lr_decay_epoch = list(
            range(args.lr_decay_period, args.epochs, args.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in args.lr_decay_epoch.split(',')]

    lr_scheduler = LRSequential([
        LRScheduler("linear",
                    base_lr=0,
                    target_lr=args.lr,
                    nepochs=args.warmup_epochs,
                    iters_per_epoch=args.batch_size),
        LRScheduler(args.lr_mode,
                    base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=args.batch_size,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay,
                    power=2),
    ])
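    # Schedule above: linear warm-up from 0 to args.lr over args.warmup_epochs,
    # then args.lr_mode decay with step drops at lr_decay_epoch (factor args.lr_decay).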
    if args.optimizer == "sgd":
        trainer = gluon.Trainer(net.collect_params(),
                                args.optimizer, {
                                    "wd": args.wd,
                                    "momentum": args.momentum,
                                    "lr_scheduler": lr_scheduler
                                },
                                kvstore="local")
    elif args.optimizer == "adam":
        trainer = gluon.Trainer(net.collect_params(),
                                args.optimizer, {"lr_scheduler": lr_scheduler},
                                kvstore="local")
    else:
        trainer = gluon.Trainer(net.collect_params(),
                                args.optimizer,
                                kvstore="local")

    # targets
    #sigmoid_ce = gluon.loss.SigmoidBinaryCrossEntropyLoss(from_sigmoid=False)
    #l1_loss = gluon.loss.L1Loss()

    # Intermediate Metrics:
    train_metrics = (
        mx.metric.Loss("ObjLoss"),
        mx.metric.Loss("BoxCenterLoss"),
        mx.metric.Loss("BoxScaleLoss"),
        mx.metric.Loss("ClassLoss"),
        mx.metric.Loss("TotalLoss"),
    )
    train_metric_ixs = range(len(train_metrics))
    target_metric_ix = -1  # Train towards TotalLoss (the last one)

    # Evaluation Metrics:
    val_metric = VOC07MApMetric(iou_thresh=0.5)

    # Data transformations:
    train_dataset = gluon_pipe_mode.AugmentedManifestDetection(
        args.train,
        length=args.num_samples_train,
    )
    # Stack the image and the 5 fixed YOLO targets; pad the variable-length gt boxes.
    train_batchify_fn = batchify.Tuple(
        *([batchify.Stack() for _ in range(6)] + [batchify.Pad(axis=0, pad_val=-1)]))
    if args.no_random_shape:
        logger.debug("Creating train DataLoader without random transform")
        train_transforms = YOLO3DefaultTrainTransform(args.data_shape,
                                                      args.data_shape,
                                                      net=async_net,
                                                      mixup=args.mixup)
        train_dataloader = gluon.data.DataLoader(
            train_dataset.transform(train_transforms),
            batch_size=args.batch_size,
            batchify_fn=train_batchify_fn,
            last_batch="discard",
            num_workers=args.num_workers,
            shuffle=False,  # Note that shuffle *cannot* be used with AugmentedManifestDetection
        )
    else:
        logger.debug("Creating train DataLoader with random transform")
        train_transforms = [
            YOLO3DefaultTrainTransform(x * 32,
                                       x * 32,
                                       net=async_net,
                                       mixup=args.mixup)
            for x in range(10, 20)
        ]
        train_dataloader = RandomTransformDataLoader(
            train_transforms,
            train_dataset,
            interval=10,
            batch_size=args.batch_size,
            batchify_fn=train_batchify_fn,
            last_batch="discard",
            num_workers=args.num_workers,
            shuffle=False,  # Note that shuffle *cannot* be used with AugmentedManifestDetection
        )
    validation_dataset = None
    validation_dataloader = None
    if args.validation:
        validation_dataset = gluon_pipe_mode.AugmentedManifestDetection(
            args.validation,
            length=args.num_samples_validation,
        )
        validation_dataloader = gluon.data.DataLoader(
            validation_dataset.transform(
                YOLO3DefaultValTransform(args.data_shape, args.data_shape)),
            args.batch_size,
            shuffle=False,
            batchify_fn=batchify.Tuple(batchify.Stack(),
                                       batchify.Pad(pad_val=-1)),
            last_batch="keep",
            num_workers=args.num_workers,
        )

    # Prepare the inference-time configuration for our model's setup:
    # (This will be saved alongside our network structure/params)
    inference_config = config.InferenceConfig(image_size=args.data_shape)

    logger.info(args)
    logger.info(f"Start training from [Epoch {args.start_epoch}]")
    prev_best_score = float("-inf")
    best_epoch = args.start_epoch
    logger.info("Sleeping for 3s in case training data file not yet ready")
    time.sleep(3)
    for epoch in range(args.start_epoch, args.start_epoch + args.epochs):
        #         if args.mixup:
        #             # TODO(zhreshold): more elegant way to control mixup during runtime
        #             try:
        #                 train_data._dataset.set_mixup(np.random.beta, 1.5, 1.5)
        #             except AttributeError:
        #                 train_data._dataset._data.set_mixup(np.random.beta, 1.5, 1.5)
        #             if epoch >= args.epochs - args.no_mixup_epochs:
        #                 try:
        #                     train_data._dataset.set_mixup(None)
        #                 except AttributeError:
        #                     train_data._dataset._data.set_mixup(None)

        tic = time.time()
        btic = time.time()
        mx.nd.waitall()
        net.hybridize()

        logger.debug(
            f"Input data dir contents: {os.listdir('/opt/ml/input/data/')}")
        for i, batch in enumerate(train_dataloader):
            logger.debug(f"Epoch {epoch}, minibatch {i}")

            batch_size = batch[0].shape[0]
            data = gluon.utils.split_and_load(batch[0],
                                              ctx_list=ctx,
                                              batch_axis=0,
                                              even_split=False)
            # objectness, center_targets, scale_targets, weights, class_targets
            fixed_targets = [
                gluon.utils.split_and_load(batch[it],
                                           ctx_list=ctx,
                                           batch_axis=0,
                                           even_split=False)
                for it in range(1, 6)
            ]
            gt_boxes = gluon.utils.split_and_load(batch[6],
                                                  ctx_list=ctx,
                                                  batch_axis=0,
                                                  even_split=False)
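            # One list per train metric; per-device losses are appended below and fed
            # to the metrics after the optimizer step (TotalLoss drives backward).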
            loss_trackers = tuple([] for _ in train_metrics)
            with autograd.record():
                for ix, x in enumerate(data):
                    losses_raw = net(x, gt_boxes[ix],
                                     *[ft[ix] for ft in fixed_targets])
                    # net outputs: [obj_loss, center_loss, scale_loss, cls_loss]
                    # Each a mx.ndarray 1xbatch_size. This is the same order as our
                    # train_metrics, so we just need to add a total vector:
                    total_loss = sum(losses_raw)
                    losses = losses_raw + [total_loss]

                    # If any sample's total loss is non-finite, sum will be:
                    if not isfinite(sum(total_loss)):
                        logger.error(
                            f"[Epoch {epoch}][Minibatch {i}] got non-finite losses: {losses_raw}"
                        )
                        # TODO: Terminate training if losses or gradient go infinite?

                    for ix in train_metric_ixs:
                        loss_trackers[ix].append(losses[ix])

                autograd.backward(loss_trackers[target_metric_ix])
            trainer.step(batch_size)
            for ix in train_metric_ixs:
                train_metrics[ix].update(0, loss_trackers[ix])

            if args.log_interval and not (i + 1) % args.log_interval:
                train_metrics_current = map(lambda metric: metric.get(),
                                            train_metrics)
                metrics_msg = "; ".join([
                    f"{name}={val:.3f}" for name, val in train_metrics_current
                ])
                logger.info(
                    f"[Epoch {epoch}][Minibatch {i}] LR={trainer.learning_rate:.2E}; "
                    f"Speed={batch_size/(time.time()-btic):.3f} samples/sec; {metrics_msg};"
                )
            btic = time.time()

        train_metrics_current = map(lambda metric: metric.get(), train_metrics)
        metrics_msg = "; ".join(
            [f"{name}={val:.3f}" for name, val in train_metrics_current])
        logger.info(
            f"[Epoch {epoch}] TrainingCost={time.time()-tic:.3f}; {metrics_msg};"
        )

        if not (epoch + 1) % args.val_interval:
            logger.info(f"Validating [Epoch {epoch}]")

            metric_names, metric_values = validate(
                net, validation_dataloader, epoch, ctx,
                VOC07MApMetric(iou_thresh=0.5), args)
            if isinstance(metric_names, list):
                val_msg = "; ".join(
                    [f"{k}={v}" for k, v in zip(metric_names, metric_values)])
                current_score = float(metric_values[-1])
            else:
                val_msg = f"{metric_names}={metric_values}"
                current_score = metric_values
            logger.info(f"[Epoch {epoch}] Validation: {val_msg};")
        else:
            current_score = float("-inf")

        save_progress(
            net,
            inference_config,
            current_score,
            prev_best_score,
            args.model_dir,
            epoch,
            args.checkpoint_interval,
            args.checkpoint_dir,
        )
        if current_score > prev_best_score:
            prev_best_score = current_score
            best_epoch = epoch

        if (args.early_stopping and epoch >= args.early_stopping_min_epochs
                and (epoch - best_epoch) >= args.early_stopping_patience):
            logger.info(
                f"[Epoch {epoch}] No improvement since epoch {best_epoch}: Stopping early"
            )
            break