def net_visualization(network=None,
                      num_classes=None,
                      data_shape=None,
                      train=None,
                      output_dir=None,
                      print_net=False,
                      net=None):
    # if a net symbol is passed in, the caller built it elsewhere; otherwise build it here
    if net is None:
        if not train:
            net = symbol_factory.get_symbol(network,
                                            data_shape,
                                            num_classes=num_classes)
        else:
            net = symbol_factory.get_symbol_train(network,
                                                  data_shape,
                                                  num_classes=num_classes)

    if not train:
        a = mx.viz.plot_network(net, shape={"data": (1, 3, data_shape, data_shape)}, \
                                node_attrs={"shape": 'rect', "fixedsize": 'false'})
        filename = "ssd_" + network + '_' + str(data_shape) + '_' + 'test'
    else:
        a = mx.viz.plot_network(net, shape=None, \
                                node_attrs={"shape": 'rect', "fixedsize": 'false'})
        filename = "ssd_" + network + '_' + 'train'

    a.render(os.path.join(output_dir, filename))
    if print_net:
        print(net.tojson())
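
# Usage sketch for net_visualization (illustrative values only; assumes the same imports
# used above -- mxnet as mx, os, symbol_factory -- and that the output directory exists):
if __name__ == '__main__':
    # render the test graph to ./ssd_vgg16_reduced_300_test.* (PDF by default)
    net_visualization(network='vgg16_reduced',
                      num_classes=20,
                      data_shape=300,
                      train=False,
                      output_dir='.')
    # render the train graph and also dump the symbol JSON to stdout
    net_visualization(network='vgg16_reduced',
                      num_classes=20,
                      data_shape=300,
                      train=True,
                      output_dir='.',
                      print_net=True)
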
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              solver,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None,
              lite=False,
              kv_store=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if lite:
        prefix += 'lite'
    prefix += '_' + net + '_' + str(data_shape[1]) + '_' + str(data_shape[2])
    print(prefix)

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(net,
                           data_shape,
                           num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress,
                           nms_topk=nms_topk,
                           lite=lite)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so freezing 'conv*' leaves them trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix, end_epoch // 2)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)

    if solver == 'sgd':
        optimizer_params = {
            'learning_rate': learning_rate,
            'momentum': momentum,
            'wd': weight_decay,
            'lr_scheduler': lr_scheduler,
            'clip_gradient': None,
            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
        }
    elif solver == 'rmsprop':
        optimizer_params = {
            'learning_rate': learning_rate,
            'gamma1': 0.5,
            'wd': weight_decay,
            'lr_scheduler': lr_scheduler,
            'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
        }
    else:
        raise ValueError("Unsupported solver: {}".format(solver))

    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=3)

    # create kvstore when there are gpus
    kv = mx.kvstore.create(kv_store) if kv_store else None

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer=solver,
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor,
            kvstore=kv)
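
# A hedged sketch of the get_lr_scheduler helper used above (the real implementation
# lives elsewhere in this repo; this is an assumption, not its exact code). It turns a
# comma-separated epoch list such as '30, 60, 90' into an
# mx.lr_scheduler.MultiFactorScheduler and pre-scales the learning rate for refactor
# steps that begin_epoch has already passed.
import math

import mxnet as mx


def get_lr_scheduler_sketch(learning_rate, lr_refactor_step, lr_refactor_ratio,
                            num_example, batch_size, begin_epoch):
    ratio = float(lr_refactor_ratio)
    if ratio <= 0.0 or ratio >= 1.0:
        return learning_rate, None  # no decay requested
    steps = [int(s) for s in str(lr_refactor_step).split(',') if s.strip()]
    epoch_size = max(1, int(math.ceil(num_example / float(batch_size))))
    # lower the starting lr for refactor steps that already happened before begin_epoch
    for step_epoch in steps:
        if begin_epoch >= step_epoch:
            learning_rate *= ratio
    remaining = [(step_epoch - begin_epoch) * epoch_size
                 for step_epoch in steps if step_epoch > begin_epoch]
    if not remaining:
        return learning_rate, None
    scheduler = mx.lr_scheduler.MultiFactorScheduler(step=remaining, factor=ratio)
    return learning_rate, scheduler
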
def train_net(network,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              alpha_bb8=1.0,
              freeze_layer_pattern='',
              num_example=5717,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None,
              optimizer='sgd',
              tensorboard=False,
              checkpoint_period=5,
              min_neg_samples=0):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    network : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    optimizer : str
        which optimizer to use instead of the default sgd
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    tensorboard : bool
        record logs into tensorboard
    min_neg_samples : int
        always have some negative examples, no matter how many positives there are.
        this is useful when training on images with no ground truth.
    checkpoint_period : int
        a checkpoint will be saved every "checkpoint_period" epochs
    """
    # check actual number of train_images
    if os.path.exists(train_path.replace('rec', 'idx')):
        with open(train_path.replace('rec', 'idx'), 'r') as f:
            txt = f.readlines()
        num_example = len(txt)
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        log_file_path = os.path.join(os.path.dirname(prefix), log_file)
        if not os.path.exists(os.path.dirname(log_file_path)):
            os.makedirs(os.path.dirname(log_file_path))
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)
    label = train_iter._batch.label[0].asnumpy()
    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
        val_label = val_iter._batch.label[0].asnumpy()
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(network,
                           data_shape[1],
                           alpha_bb8,
                           num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress,
                           nms_topk=nms_topk,
                           minimum_negative_samples=min_neg_samples)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # check what layers mismatch with the loaded parameters
        exe = net.simple_bind(mx.cpu(),
                              data=(1, 3, 300, 300),
                              label=(1, 1, 5),
                              grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    logging.info("Removed %s" % k)
                else:
                    if 'pred' not in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # visualize net - both train and test
    net_visualization(net=net,
                      network=network,
                      data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix),
                      train=True)
    # net_visualization(net=None, network=network, data_shape=data_shape[2],
    #                   output_dir=os.path.dirname(prefix), train=False, num_classes=num_classes)

    # init training module
    data_names = [k[0] for k in train_iter.provide_data]
    label_names = [k[0] for k in train_iter.provide_label]
    mod = mx.mod.Module(net,
                        data_names=data_names,
                        label_names=label_names,
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    batch_end_callback = []
    eval_end_callback = []
    epoch_end_callback = [
        mx.callback.do_checkpoint(prefix, period=checkpoint_period)
    ]

    # add logging to tensorboard
    if tensorboard:
        tensorboard_dir = os.path.join(os.path.dirname(prefix), 'logs')
        if not os.path.exists(tensorboard_dir):
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'dist'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'roc'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'images'))
        batch_end_callback.append(
            ParseLogCallback(
                dist_logging_dir=os.path.join(tensorboard_dir, 'train',
                                              'dist'),
                scalar_logging_dir=os.path.join(tensorboard_dir, 'train',
                                                'scalar'),
                logfile_path=log_file_path,
                batch_size=batch_size,
                iter_monitor=iter_monitor,
                frequent=frequent))
        eval_end_callback.append(
            LogMetricsCallback(os.path.join(tensorboard_dir, 'val/scalar'),
                               'ssd',
                               global_step=0))
        # eval_end_callback.append(LogROCCallback(logging_dir=os.path.join(tensorboard_dir, 'val/roc'),
        #                                         roc_path=os.path.join(os.path.dirname(prefix), 'roc'),
        #                                         class_names=class_names))
        # eval_end_callback.append(LogDetectionsCallback(logging_dir=os.path.join(tensorboard_dir, 'val/images'),
        #                                                images_path=os.path.join(os.path.dirname(prefix), 'images'),
        #                                                class_names=class_names,batch_size=batch_size,mean_pixels=mean_pixels))

    # this callback should be the last in a series of batch_callbacks
    # since it resets the metric evaluation every $frequent batches
    batch_end_callback.append(
        mx.callback.Speedometer(train_iter.batch_size, frequent=frequent))

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    logger.info(
        "learning rate: {}, lr refactor step: {}, lr refactor ratio: {}, batch size: {}."
        .format(learning_rate, lr_refactor_step, lr_refactor_ratio,
                batch_size))
    # add possibility for different optimizer
    opt, opt_params = get_optimizer_params(optimizer=optimizer,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           weight_decay=weight_decay,
                                           lr_scheduler=lr_scheduler,
                                           ctx=ctx,
                                           logger=logger)
    logger.info("Optimizer: {}".format(opt))
    for k, v in opt_params.items():
        if k == 'lr_scheduler':
            continue
        logger.info("{}: {}".format(k, v))

    # TODO monitor the gradient flow as in 'https://github.com/dmlc/tensorboard/blob/master/docs/tutorial/understanding-vanish-gradient.ipynb'
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      class_names,
                                      pred_idx=4,
                                      roc_output_path=os.path.join(
                                          os.path.dirname(prefix), 'roc'))
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=4,
                                 roc_output_path=os.path.join(
                                     os.path.dirname(prefix), 'roc'))

    mod.fit(
        train_iter,
        val_iter,
        eval_metric=MultiBoxMetric(),
        validation_metric=MultiBoxMetric(),  # use 'valid_metric' here to compute mAP
        batch_end_callback=batch_end_callback,
        eval_end_callback=eval_end_callback,
        epoch_end_callback=epoch_end_callback,
        optimizer=opt,
        optimizer_params=opt_params,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        initializer=mx.init.Xavier(),
        arg_params=args,
        aux_params=auxs,
        allow_missing=True,
        monitor=monitor)
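
# A hedged sketch of the get_optimizer_params helper called above (assumed behaviour,
# not the repo's actual implementation): map an optimizer name to the (name, params)
# pair handed to mod.fit, mirroring the sgd/rmsprop branches of the first wrapper.
def get_optimizer_params_sketch(optimizer, learning_rate, momentum, weight_decay,
                                lr_scheduler, ctx, logger=None):
    rescale_grad = 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    params = {'learning_rate': learning_rate,
              'wd': weight_decay,
              'lr_scheduler': lr_scheduler,
              'rescale_grad': rescale_grad}
    name = optimizer.lower()
    if name == 'sgd':
        params.update({'momentum': momentum, 'clip_gradient': None})
    elif name == 'rmsprop':
        params['gamma1'] = 0.9
    elif name == 'adam':
        pass  # mxnet's Adam defaults for beta1/beta2 are kept
    else:
        raise ValueError('unsupported optimizer: {}'.format(optimizer))
    if logger is not None:
        logger.info('optimizer params: {}'.format(params))
    return name, params
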
def train_net_common(net, train_iter, val_iter, batch_size,
                     data_shape, resume, finetune, pretrained, epoch,
                     prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
                     momentum, weight_decay, use_plateau, lr_refactor_step, lr_refactor_ratio,
                     freeze_layer_pattern='',
                     num_example=10000, label_pad_width=350,
                     nms_thresh=0.45, force_suppress=False, ovp_thresh=0.5,
                     use_difficult=False, class_names=None,
                     optimizer_name='sgd',
                     voc07_metric=False, nms_topk=400,
                     iter_monitor=0,
                     monitor_pattern=".*", logger=None):
    """
    """
    # check args
    prefix += '_' + net + '_' + str(data_shape[1])

    # load symbol
    net_str = net
    net = get_symbol_train(net, data_shape[1], num_classes=len(class_names),
        nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so freezing 'conv*' leaves them trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        try:
            logger.info("Start training with {} from pretrained model {}"
                .format(ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        except Exception:
            logger.info("Failed to load the pretrained model. Start from scratch.")
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    if not use_plateau: # focal loss does not go well with plateau
        mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net, label_names=('label',), logger=logger, context=ctx,
                fixed_param_names=fixed_param_names)

    # robust parameter setting
    mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent, auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
    optimizer_params = {'learning_rate': learning_rate,
                        'wd': weight_decay,
                        'clip_gradient': 4.0,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum

    # #7847
    mod.init_optimizer(optimizer=optimizer_name, optimizer_params=optimizer_params, force_init=True)

    if not use_plateau:
        learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
                lr_refactor_ratio, num_example, batch_size, begin_epoch)
    else:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {'CrossEntropy': 1.0, 'SmoothL1': w_l1}
        plateau_lr = PlateauScheduler( \
                patient_epochs=lr_refactor_step, factor=float(lr_refactor_ratio), eval_weights=eval_weights)
        plateau_metric = MultiBoxMetric(fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')

    eval_metric = MultiBoxMetric()
    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr, plateau_metric=plateau_metric,
                fn_curr_model=prefix+'-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
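
# A hedged sketch of set_mod_params as used in train_net_common above (an assumption,
# not the repo's code): initialize the bound module with the loaded arg/aux params while
# tolerating missing entries, so newly added layers fall back to the Xavier initializer.
def set_mod_params_sketch(mod, args, auxs, logger=None):
    mod.init_params(initializer=mx.init.Xavier(),
                    arg_params=args,
                    aux_params=auxs,
                    allow_missing=True,
                    force_init=True)
    if logger is not None:
        logger.info('module parameters set (allow_missing=True)')
    return mod
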
def train_net(network, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None, optimizer='sgd', tensorboard=False,
              checkpoint_period=5, min_neg_samples=0):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    network : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    optimizer : str
        which optimizer to use instead of the default sgd
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    tensorboard : bool
        record logs into tensorboard
    min_neg_samples : int
        always have some negative examples, no matter how many positives there are.
        this is useful when training on images with no ground truth.
    checkpoint_period : int
        a checkpoint will be saved every "checkpoint_period" epochs
    """
    # check actual number of train_images
    if os.path.exists(train_path.replace('rec','idx')):
        with open(train_path.replace('rec','idx'), 'r') as f:
            txt = f.readlines()
        num_example = len(txt)
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        log_file_path = os.path.join(os.path.dirname(prefix), log_file)
        if not os.path.exists(os.path.dirname(log_file_path)):
            os.makedirs(os.path.dirname(log_file_path))
        fh = logging.FileHandler(log_file_path)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(network, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk, minimum_negative_samples=min_neg_samples)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # check what layers mismatch with the loaded parameters
        exe = net.simple_bind(mx.cpu(), data=(1, 3, 300, 300), label=(1, 1, 5), grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    logging.info("Removed %s" % k)
                else:
                    if 'pred' not in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # visualize net - both train and test
    net_visualization(net=net, network=network, data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix), train=True)
    net_visualization(net=None, network=network, data_shape=data_shape[2],
                      output_dir=os.path.dirname(prefix), train=False, num_classes=num_classes)

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    batch_end_callback = []
    eval_end_callback = []
    epoch_end_callback = [mx.callback.do_checkpoint(prefix, period=checkpoint_period)]

    # add logging to tensorboard
    if tensorboard:
        tensorboard_dir = os.path.join(os.path.dirname(prefix), 'logs')
        if not os.path.exists(tensorboard_dir):
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'train', 'dist'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'roc'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'scalar'))
            os.makedirs(os.path.join(tensorboard_dir, 'val', 'images'))
        batch_end_callback.append(
            ParseLogCallback(dist_logging_dir=os.path.join(tensorboard_dir, 'train', 'dist'),
                             scalar_logging_dir=os.path.join(tensorboard_dir, 'train', 'scalar'),
                             logfile_path=log_file_path, batch_size=batch_size, iter_monitor=iter_monitor,
                             frequent=frequent))
        eval_end_callback.append(mx.contrib.tensorboard.LogMetricsCallback(
            os.path.join(tensorboard_dir, 'val/scalar'), 'ssd'))
        eval_end_callback.append(LogROCCallback(logging_dir=os.path.join(tensorboard_dir, 'val/roc'),
                                                roc_path=os.path.join(os.path.dirname(prefix), 'roc'),
                                                class_names=class_names))
        eval_end_callback.append(LogDetectionsCallback(
            logging_dir=os.path.join(tensorboard_dir, 'val/images'),
            images_path=os.path.join(os.path.dirname(prefix), 'images'),
            class_names=class_names, batch_size=batch_size, mean_pixels=mean_pixels))

    # this callback should be the last in a series of batch_callbacks
    # since it resets the metric evaluation every $frequent batches
    batch_end_callback.append(mx.callback.Speedometer(train_iter.batch_size, frequent=frequent))

    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
                                                   lr_refactor_ratio, num_example, batch_size, begin_epoch)
    # add possibility for different optimizer
    opt, opt_params = get_optimizer_params(optimizer=optimizer, learning_rate=learning_rate, momentum=momentum,
                                           weight_decay=weight_decay, lr_scheduler=lr_scheduler, ctx=ctx, logger=logger)
    # TODO monitor the gradient flow as in 'https://github.com/dmlc/tensorboard/blob/master/docs/tutorial/understanding-vanish-gradient.ipynb'
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3,
                                      roc_output_path=os.path.join(os.path.dirname(prefix), 'roc'))
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3,
                                 roc_output_path=os.path.join(os.path.dirname(prefix), 'roc'))

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            eval_end_callback=eval_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer=opt,
            optimizer_params=opt_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
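
# A hedged sketch of convert_pretrained, which every wrapper above calls after loading a
# pretrained checkpoint (an assumption based on the reference SSD example, not this
# repo's exact code): a hook for renaming or pruning incompatible weights, which in the
# simplest case is a pass-through.
def convert_pretrained_sketch(name, args):
    # place model-specific renaming/pruning here; returning args unchanged covers the
    # common case, since mod.fit(..., allow_missing=True) fills anything that is absent
    return args
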
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None, kv_store=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
        label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
            label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
        nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so freezing 'conv*' leaves them trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
        lr_refactor_ratio, num_example, batch_size, begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

    # create kvstore when there are gpus
    kv = mx.kvstore.create(kv_store) if kv_store else None

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor,
            kvstore=kv)
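
# A brief hedged note on the kv_store argument used above: typical command-line values
# are 'local' (CPU-side gradient aggregation), 'device' (GPU-side aggregation) or
# 'dist_sync' for multi-machine jobs; the created store is simply forwarded to mod.fit.
example_kv_store = 'device'  # illustrative value
example_kv = mx.kvstore.create(example_kv_store) if example_kv_store else None
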
parser.add_argument('--network',
                    type=str,
                    default='vgg16_reduced',
                    help='the cnn to use')
parser.add_argument('--num-classes',
                    type=int,
                    default=20,
                    help='the number of classes')
parser.add_argument('--data-shape',
                    type=int,
                    default=300,
                    help='set image\'s shape')
parser.add_argument('--train',
                    action='store_true',
                    default=False,
                    help='show train net')
args = parser.parse_args()

if not args.train:
    net = symbol_factory.get_symbol(args.network,
                                    args.data_shape,
                                    num_classes=args.num_classes)
    a = mx.viz.plot_network(net, shape={"data": (1, 3, args.data_shape, args.data_shape)},
                            node_attrs={"shape": 'rect', "fixedsize": 'false'})
    a.render("ssd_" + args.network + '_' + str(args.data_shape))
else:
    net = symbol_factory.get_symbol_train(args.network,
                                          args.data_shape,
                                          num_classes=args.num_classes)
    print(net.tojson())
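
# A small hedged follow-up (file name is illustrative): the symbol JSON can also be
# written to disk and reloaded with mx.sym.load for offline inspection of the graph.
symbol_json = net.tojson()
with open('ssd_' + args.network + '_symbol.json', 'w') as f:
    f.write(symbol_json)
reloaded = mx.sym.load('ssd_' + args.network + '_symbol.json')
print(reloaded.list_outputs())
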
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              use_plateau,
              lr_refactor_step,
              lr_refactor_ratio,
              use_global_stats=0,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              ignore_names=None,
              optimizer_name='sgd',
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in the record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net_str = net
    net = get_symbol_train(net, data_shape[1],
                           use_global_stats=use_global_stats,
                           num_classes=num_classes, ignore_names=ignore_names,
                           nms_thresh=nms_thresh, force_suppress=force_suppress,
                           nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None
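    # For instance, freeze_layer_pattern='conv1_|conv2_' would freeze every argument
    # whose name starts with 'conv1_' or 'conv2_' (re.match anchors at the start),
    # while an empty pattern leaves all layers trainable.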

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so they are
        # not caught by the 'conv' prefix below and remain trainable
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        try:
            logger.info(
                "Start training with {} from pretrained model {}".format(
                    ctx_str, pretrained))
            _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
            args = convert_pretrained(pretrained, args)
            if net_str == 'ssd_pva':
                args, auxs = convert_pvanet(args, auxs)
        except Exception:
            logger.info(
                "Failed to load the pretrained model. Start from scratch.")
            args = None
            auxs = None
            fixed_param_names = None
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    if not use_plateau:  # focal loss does not go well with plateau
        mod = mx.mod.Module(net,
                            label_names=('label', ),
                            logger=logger,
                            context=ctx,
                            fixed_param_names=fixed_param_names)
    else:
        mod = PlateauModule(net,
                            label_names=('label', ),
                            logger=logger,
                            context=ctx,
                            fixed_param_names=fixed_param_names)

    # robust parameter setting
    mod.bind(data_shapes=train_iter.provide_data,
             label_shapes=train_iter.provide_label)
    mod = set_mod_params(mod, args, auxs, logger)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent,
                                                 auto_reset=True)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None
    optimizer_params = {
        'learning_rate': learning_rate,
        'wd': weight_decay,
        'clip_gradient': 4.0,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    if optimizer_name == 'sgd':
        optimizer_params['momentum'] = momentum
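    # rescale_grad scales the aggregated gradient by 1/len(ctx), e.g. 0.25 when
    # training on four GPUs; momentum only applies to momentum-based optimizers
    # such as 'sgd', hence the conditional above.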

    # #7847
    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params,
                       force_init=True)

    if not use_plateau:
        learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                       lr_refactor_step,
                                                       lr_refactor_ratio,
                                                       num_example, batch_size,
                                                       begin_epoch)
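        # Example: lr_refactor_step='80,160' with lr_refactor_ratio=0.1 yields a
        # scheduler that multiplies the learning rate by 0.1 around epochs 80 and
        # 160 (epochs presumably converted to iterations via num_example / batch_size).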
    else:
        w_l1 = cfg.train['smoothl1_weight']
        eval_weights = {
            'CrossEntropy': 1.0,
            'SmoothL1': w_l1,
            'ObjectRecall': 0.0
        }
        plateau_lr = PlateauScheduler(patient_epochs=lr_refactor_step,
                                      factor=float(lr_refactor_ratio),
                                      eval_weights=eval_weights)
        plateau_metric = MultiBoxMetric(
            fn_stat='/home/hyunjoon/github/additions_mxnet/ssd/stat.txt')
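        # With the plateau scheduler the learning rate is reduced by `factor` only
        # when the weighted evaluation metrics stop improving for `patient_epochs`,
        # rather than at fixed epochs; eval_weights balances the classification and
        # localization losses when deciding whether progress has plateaued.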

    mod.init_optimizer(optimizer=optimizer_name,
                       optimizer_params=optimizer_params)

    eval_metric = MultiBoxMetric()
    # run fit; the evaluation network is run periodically on the validation set to report mAP
    if voc07_metric:
        map_metric = VOC07MApMetric(ovp_thresh,
                                    use_difficult,
                                    class_names,
                                    pred_idx=4)
        recall_metric = RecallMetric(ovp_thresh, use_difficult, pred_idx=4)
        valid_metric = mx.metric.create([map_metric, recall_metric])
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=4)

    if not use_plateau:
        mod.fit(train_iter,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
    else:
        mod.fit(train_iter,
                plateau_lr,
                plateau_metric=plateau_metric,
                fn_curr_model=prefix + '-1000.params',
                plateau_backtrace=False,
                eval_data=val_iter,
                eval_metric=eval_metric,
                validation_metric=valid_metric,
                validation_period=5,
                kvstore='local',
                batch_end_callback=batch_end_callback,
                epoch_end_callback=epoch_end_callback,
                optimizer=optimizer_name,
                optimizer_params=optimizer_params,
                begin_epoch=begin_epoch,
                num_epoch=end_epoch,
                initializer=mx.init.Xavier(),
                arg_params=args,
                aux_params=auxs,
                allow_missing=True,
                monitor=monitor)
Example #9
0
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=2000,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma-separated integers
        epochs at which to rescale the learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers that need to be fixed (frozen)
    num_example : int
        number of training images
    label_pad_width : int
        force padding of training and validation labels so their widths match
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapping objects even if they belong to different classes
    train_list : str
        list file path for training; if given, it replaces the labels embedded in the record file
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; if given, it replaces the labels embedded in the record file
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # optionally also log to a file
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])
    # expand a scalar mean_pixels value into per-channel RGB means
    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path,
                               batch_size,
                               data_shape,
                               mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width,
                               path_imglist=train_list,
                               **cfg.train)

    # Optional debug block: step through the training iterator and draw the
    # ground-truth boxes on a couple of images per batch (uncomment to inspect
    # the augmented data).
    # for c in range(12840):
    #     batch = train_iter.next()
    #     data = batch.data[0]
    #     label = batch.label[0]
    #     from matplotlib import pyplot as plt
    #     import numpy as np
    #     import cv2
    #     for i in range(2):
    #         plt.subplot(1, 2, i + 1)
    #         img = np.array(data[i].asnumpy().transpose(1, 2, 0).copy(), np.uint8)
    #         box = label[i].asnumpy()
    #         bbox = []
    #         print('The', i, 'th image')
    #         for j in range(box.shape[0]):
    #             if box[j][0] == -1:  # -1 marks padding entries
    #                 break
    #             bbox.append(box[j][1:5])
    #         for k in range(len(bbox)):
    #             # labels are normalized to [0, 1]; scale back to pixel coordinates
    #             xmin = (bbox[k][0] * img.shape[0]).astype(np.int16)
    #             ymin = (bbox[k][1] * img.shape[0]).astype(np.int16)
    #             xmax = (bbox[k][2] * img.shape[0]).astype(np.int16)
    #             ymax = (bbox[k][3] * img.shape[0]).astype(np.int16)
    #             cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (255, 0, 0), 4)
    #             print('xmin', xmin, 'ymin', ymin, 'xmax', xmax, 'ymax', ymax)
    #         plt.imshow(img)
    #     plt.show()
    #     # path = 'crop_image/' + str(c) + '.jpg'
    #     # plt.savefig(path)
    #     print(batch)

    if val_path:
        val_iter = DetRecordIter(val_path,
                                 batch_size,
                                 data_shape,
                                 mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width,
                                 path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None
    # load symbol
    net = get_symbol_train(net,
                           data_shape[1],
                           num_classes=num_classes,
                           nms_thresh=nms_thresh,
                           force_suppress=force_suppress,
                           nms_topk=nms_topk)
    # viz = mx.viz.plot_network(net)
    # viz.view()
    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so they are
        # not caught by the 'conv' prefix below and remain trainable
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        fixed_param_names = None
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    mod = mx.mod.Module(net,
                        label_names=('label', ),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
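    # do_checkpoint saves the symbol as '<prefix>-symbol.json' and the weights as
    # '<prefix>-%04d.params' at the end of every epoch, which is also the layout
    # mx.model.load_checkpoint expects when resuming above.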
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
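    # clip_gradient=None disables gradient clipping (the plateau-capable example
    # above clips at 4.0); rescale_grad again divides the gradient by the number
    # of contexts, and lr_scheduler applies the step schedule built above.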
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit; the evaluation network is run periodically on the validation set to report mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh,
                                      use_difficult,
                                      class_names,
                                      pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh,
                                 use_difficult,
                                 class_names,
                                 pred_idx=3)
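    # pred_idx=3 tells the metric which output of the symbol carries the detection
    # results; the plateau-based example above uses pred_idx=4, so this index
    # depends on the particular training symbol.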

    mod.fit(
        train_data=train_iter,
        eval_data=val_iter,
        eval_metric=MultiBoxMetric(),
        validation_metric=valid_metric,
        batch_end_callback=batch_end_callback,
        epoch_end_callback=epoch_end_callback,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        initializer=mx.init.Xavier(),
        arg_params=args,
        aux_params=auxs,
        allow_missing=True,
        monitor=monitor)