from collections import OrderedDict

def state_dict(self, destination=None):
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    # collect_params (assumed in scope) gathers a network's weights into a
    # serializable structure.
    destination['supernet_params'] = collect_params(self.supernet)
    destination['controller_params'] = collect_params(self.controller)
    destination['training_history'] = self.training_history
    return destination
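
A minimal usage sketch, assuming a searcher object that defines the method above (the name searcher and the file path are illustrative):

import pickle

state = searcher.state_dict()          # hypothetical owner of this method
with open('searcher_state.pkl', 'wb') as f:
    pickle.dump(state, f)              # an OrderedDict pickles directly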
Example #2
import pickle as pkl
from collections import OrderedDict

def state_dict(self, destination=None):
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    # collect_params (assumed in scope) gathers the model's weights.
    model_params = collect_params(self.model)
    destination['model_params'] = model_params
    destination['results'] = pkl.dumps(self.results)
    destination['scheduler_checkpoint'] = self.scheduler_checkpoint
    destination['args'] = self.args
    # Replace the dataset object with its class list so the dict stays small
    # and picklable.
    destination['classes'] = destination['args'].pop('dataset').get_classes()
    return destination
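
A sketch of reading the dict back, assuming a fit_job object that defines the method above (the name is illustrative); 'results' holds pickled bytes and must be decoded:

import pickle as pkl

state = fit_job.state_dict()           # hypothetical owner of this method
results = pkl.loads(state['results'])  # decode the pickled results payload
print(state['classes'])                # class list popped from args['dataset']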
Example #3
import warnings

import mxnet as mx
from mxnet import gluon
import gluoncv as gcv
from gluoncv import utils as gutils
from gluoncv.data.transforms.presets.rcnn import (
    FasterRCNNDefaultTrainTransform, FasterRCNNDefaultValTransform)

# get_dataloader, get_faster_rcnn_dataloader, train, and collect_params are
# helpers assumed to be defined elsewhere in this module.

def train_object_detection(args, reporter):
    # Fix the seed for MXNet, NumPy, and Python's builtin random generator.
    gutils.random.seed(args.seed)

    # training contexts
    ctx = ([mx.gpu(i) for i in range(args.num_gpus)]
           if args.num_gpus > 0 else [mx.cpu()])
    if args.meta_arch == 'yolo3':
        net_name = '_'.join((args.meta_arch, args.net, 'custom'))
        kwargs = {}
    elif args.meta_arch == 'faster_rcnn':
        net_name = '_'.join(('custom', args.meta_arch, 'fpn'))
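        # The dict below mirrors GluonCV's stock Faster R-CNN FPN settings:
        # FPN stages 2-6, anchor strides (4, 8, 16, 32, 64), and box-regression
        # clipping at log(1000/16) ~= 4.14.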
        kwargs = {
            'base_network_name': args.net,
            'short': args.data_shape,
            'max_size': 1000,
            'nms_thresh': 0.5,
            'nms_topk': -1,
            'min_stage': 2,
            'max_stage': 6,
            'post_nms': -1,
            'roi_mode': 'align',
            'roi_size': (7, 7),
            'strides': (4, 8, 16, 32, 64),
            'clip': 4.14,
            'rpn_channel': 256,
            'base_size': 16,
            'scales': (2, 4, 8, 16, 32),
            'ratios': (0.5, 1, 2),
            'alloc_size': (384, 384),
            'rpn_nms_thresh': 0.7,
            'rpn_train_pre_nms': 12000,
            'rpn_train_post_nms': 2000,
            'rpn_test_pre_nms': 6000,
            'rpn_test_post_nms': 1000,
            'rpn_min_size': 1,
            'per_device_batch_size': args.batch_size // args.num_gpus,
            'num_sample': 512,
            'pos_iou_thresh': 0.5,
            'pos_ratio': 0.25,
            'max_num_gt': 100
        }
    else:
        raise NotImplementedError(f'{args.meta_arch} is not implemented.')
    args.save_prefix += net_name

    # use sync bn if specified
    if args.syncbn and len(ctx) > 1:
        net = gcv.model_zoo.get_model(
            net_name,
            classes=args.dataset.get_classes(),
            pretrained_base=True,
            transfer=args.transfer,
            norm_layer=gluon.contrib.nn.SyncBatchNorm,
            norm_kwargs={'num_devices': len(ctx)},
            **kwargs)
        if not args.reuse_pred_weights:
            net.reset_class(args.dataset.get_classes(), reuse_weights=None)
        if args.meta_arch == 'yolo3':
            async_net = gcv.model_zoo.get_model(
                net_name,
                classes=args.dataset.get_classes(),
                pretrained_base=True,
                transfer=args.transfer,
                **kwargs)
            if not args.reuse_pred_weights:
                async_net.reset_class(args.dataset.get_classes(),
                                      reuse_weights=None)
    else:
        net = gcv.model_zoo.get_model(net_name,
                                      classes=args.dataset.get_classes(),
                                      pretrained_base=True,
                                      transfer=args.transfer,
                                      **kwargs)
        if not args.reuse_pred_weights:
            net.reset_class(args.dataset.get_classes(), reuse_weights=None)
        async_net = net
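        # Without SyncBatchNorm the same network instance can serve both
        # training and the dataloader's target generation.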

    if args.resume.strip():
        net.load_parameters(args.resume.strip())
        if args.meta_arch == 'yolo3':
            async_net.load_parameters(args.resume.strip())
    else:
        # Capture (and thereby silence) the warnings that initialize() emits
        # for parameters that were already initialized.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            net.initialize()
            if args.meta_arch == 'yolo3':
                async_net.initialize()

    # training data
    train_dataset, eval_metric = args.dataset.get_dataset_and_metric()
    if args.meta_arch == 'yolo3':
        train_data, val_data = get_dataloader(async_net, train_dataset, None,
                                              args.data_shape, args.batch_size,
                                              args.num_workers, args)
    elif args.meta_arch == 'faster_rcnn':
        train_data, val_data = get_faster_rcnn_dataloader(
            net, train_dataset, None, FasterRCNNDefaultTrainTransform,
            FasterRCNNDefaultValTransform, args.batch_size, args.num_gpus,
            args)

    # training
    train(net, train_data, val_data, eval_metric, ctx, args, reporter,
          args.final_fit)

    if args.final_fit:
        return {'model_params': collect_params(net)}
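
A heavily abridged sketch of calling this entry point. Every field below is read by the function, but the values, the my_dataset object (which must provide get_classes() and get_dataset_and_metric()), and the reporter callback are illustrative stand-ins; the train helper reads further fields not shown here:

import argparse

args = argparse.Namespace(
    seed=233, num_gpus=1, meta_arch='yolo3', net='darknet53',
    batch_size=16, num_workers=4, data_shape=416, syncbn=False,
    transfer=None, reuse_pred_weights=True, resume='', final_fit=False,
    save_prefix='od_', dataset=my_dataset)  # my_dataset: hypothetical

train_object_detection(args, reporter=lambda **kw: print(kw))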
Example #4
import logging

import mxnet as mx
from mxnet import gluon, nd
from gluoncv.loss import DistillationSoftmaxCrossEntropyLoss
from tqdm import tqdm

# Sample_params, Getmodel_kwargs, LR_params, get_data_loader,
# get_metric_instance, default_train_fn, default_val_fn, and collect_params
# are helpers assumed to be defined elsewhere in this module.

def train_image_classification(args, reporter):
    logging.basicConfig()
    logger = logging.getLogger(__name__)
    if args.verbose:
        logger.setLevel(logging.INFO)
        logger.info(args)

    target_params = Sample_params(args.batch_size, args.num_gpus,
                                  args.num_workers)
    batch_size = target_params.get_batchsize
    ctx = target_params.get_context
    classes = args.dataset.num_classes if hasattr(args.dataset,
                                                  'num_classes') else None
    target_kwargs = Getmodel_kwargs(ctx, classes, args.net,
                                    args.tricks.teacher_name,
                                    args.tricks.hard_weight, args.hybridize,
                                    args.optimizer.multi_precision,
                                    args.tricks.use_pretrained,
                                    args.tricks.use_gn, args.tricks.last_gamma,
                                    args.tricks.batch_norm, args.tricks.use_se)
    distillation = target_kwargs.distillation
    net = target_kwargs.get_net
    input_size = net.input_size if hasattr(net,
                                           'input_size') else args.input_size

    if args.tricks.no_wd:
        for k, v in net.collect_params('.*beta|.*gamma|.*bias').items():
            v.wd_mult = 0.0

    if args.tricks.label_smoothing or args.tricks.mixup:
        sparse_label_loss = False
    else:
        sparse_label_loss = True

    if distillation:
        teacher = target_kwargs.get_teacher

        # Soften the teacher's logits at the distillation temperature and
        # return per-device class probabilities.
        def teacher_prob(data):
            return [
                nd.softmax(
                    teacher(X.astype(target_kwargs.dtype, copy=False)) /
                    args.tricks.temperature) for X in data
            ]

        L = DistillationSoftmaxCrossEntropyLoss(
            temperature=args.tricks.temperature,
            hard_weight=args.tricks.hard_weight,
            sparse_label=sparse_label_loss)
    else:
        L = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=sparse_label_loss)
        teacher_prob = None
    if args.tricks.mixup:
        # With mixup the labels are soft mixtures, so an accuracy-style metric
        # does not apply; track RMSE against the mixed targets instead.
        metric = get_metric_instance('rmse')
    else:
        metric = get_metric_instance(args.metric)

    train_data, val_data, batch_fn, num_batches = get_data_loader(
        args.dataset, input_size, batch_size, args.num_workers, args.final_fit,
        args.split_ratio)

    # Build an LR scheduler when lr_mode is given as a string; otherwise
    # assume a ready-made scheduler object was passed in directly.
    if isinstance(args.lr_config.lr_mode, str):
        target_lr = LR_params(
            args.optimizer.lr, args.lr_config.lr_mode, args.epochs,
            num_batches, args.lr_config.lr_decay_epoch,
            args.lr_config.lr_decay, args.lr_config.lr_decay_period,
            args.lr_config.warmup_epochs, args.lr_config.warmup_lr)
        lr_scheduler = target_lr.get_lr_scheduler
    else:
        lr_scheduler = args.lr_config.lr_mode
    args.optimizer.lr_scheduler = lr_scheduler

    trainer = gluon.Trainer(net.collect_params(), args.optimizer)

    def train(epoch, num_epochs, metric):
        for i, batch in enumerate(train_data):
            metric = default_train_fn(
                epoch, num_epochs, net, batch, batch_size, L, trainer,
                batch_fn, ctx, args.tricks.mixup, args.tricks.label_smoothing,
                distillation, args.tricks.mixup_alpha,
                args.tricks.mixup_off_epoch, classes, target_kwargs.dtype,
                metric, teacher_prob)
            mx.nd.waitall()
        return metric

    def test(epoch):
        metric.reset()
        for i, batch in enumerate(val_data):
            default_val_fn(net, batch, batch_fn, metric, ctx,
                           target_kwargs.dtype)
        _, reward = metric.get()
        reporter(epoch=epoch, classification_reward=reward)
        return reward

    # Note: epoch must start with 1, not 0
    tbar = tqdm(range(1, args.epochs + 1))
    for epoch in tbar:
        metric = train(epoch, args.epochs, metric)
        train_metric_name, train_metric_score = metric.get()
        tbar.set_description(
            f'[Epoch {epoch}] training: {train_metric_name}={train_metric_score:.3f}'
        )
        if not args.final_fit:
            reward = test(epoch)
            tbar.set_description(f'[Epoch {epoch}] validation: {reward:.3f}')

    if args.final_fit:
        return {'model_params': collect_params(net), 'num_classes': classes}
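
A similarly hedged sketch of wiring in the reporter. print_reporter stands in for the framework's callback, and args is assumed to be a fully populated namespace analogous to the object-detection sketch above (with fields such as epochs, metric, tricks, lr_config, optimizer, and dataset):

def print_reporter(**kwargs):
    # Receives the epoch and classification_reward emitted by test().
    print(kwargs)

out = train_image_classification(args, print_reporter)
if out is not None:  # a payload is returned only when args.final_fit is set
    model_params, num_classes = out['model_params'], out['num_classes']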