def run():
    """Build the MXNet network, data iterators and metrics, then start training.

    All configuration is read from module-level ``param_*`` globals and nothing
    is returned; the actual training loop is delegated to
    ``train_GOCD.start_train``.

    A validation iterator (and its metric) is only built when
    ``param_valset_pickle_file_path`` is not ``''`` and
    ``param_val_batch_size``, ``param_num_val_loops`` and
    ``param_num_thread_val_dataiter`` are all non-zero; otherwise ``None`` is
    passed for both.
    """
    logging_GOCD.init_logging(log_file_path=param_log_file_path,
                              log_file_mode=param_log_mode)

    logging.info('Preparing before training.')
    # Presumably makes project packages such as symbol_farm importable
    # from the parent directory.
    sys.path.append('..')
    from symbol_farm import symbol_10_320_20L_5scales_v2 as net

    net_symbol, data_names, label_names = net.get_net_symbol()
    net_initializer = mxnet.initializer.Xavier()

    logging.info('Get net symbol successfully.')

    # -----------------------------------------------------------------------------------------------
    # construct data iterators
    from data_provider_farm.pickle_provider import PickleProvider
    from data_iterator_farm.multithread_dataiter_for_cross_entropy_v2 import \
        Multithread_DataIter_for_CrossEntropy as DataIter

    # Keyword arguments shared by the training and validation iterators. Only
    # the data provider, thread count and batch size differ between the two,
    # so keeping the rest in one place prevents the two calls from drifting.
    # NOTE(review): the validation iterator receives the same augmentation
    # settings (flips, brightness, saturation, contrast, blur) as training —
    # confirm this is intended.
    common_dataiter_kwargs = dict(
        mxnet_module=mxnet,
        enable_horizon_flip=param_enable_horizon_flip,
        enable_vertical_flip=param_enable_vertical_flip,
        enable_random_brightness=param_enable_random_brightness,
        brightness_params=param_brightness_factors,
        enable_random_saturation=param_enable_random_saturation,
        saturation_params=param_saturation_factors,
        enable_random_contrast=param_enable_random_contrast,
        contrast_params=param_contrast_factors,
        enable_blur=param_enable_blur,
        blur_params=param_blur_factors,
        blur_kernel_size_list=param_blur_kernel_size_list,
        neg_image_ratio=param_neg_image_ratio,
        num_image_channels=param_num_image_channel,
        net_input_height=param_net_input_height,
        net_input_width=param_net_input_width,
        num_output_scales=param_num_output_scales,
        receptive_field_list=param_receptive_field_list,
        receptive_field_stride=param_receptive_field_stride,
        feature_map_size_list=param_feature_map_size_list,
        receptive_field_center_start=param_receptive_field_center_start,
        bbox_small_list=param_bbox_small_list,
        bbox_large_list=param_bbox_large_list,
        bbox_small_gray_list=param_bbox_small_gray_list,
        bbox_large_gray_list=param_bbox_large_gray_list,
        num_output_channels=param_num_output_channels,
        neg_image_resize_factor_interval=param_neg_image_resize_factor_interval,
    )

    train_data_provider = PickleProvider(param_trainset_pickle_file_path)
    train_dataiter = DataIter(
        num_threads=param_num_thread_train_dataiter,
        data_provider=train_data_provider,
        batch_size=param_train_batch_size,
        **common_dataiter_kwargs)

    # Validation is optional: any empty/zero validation setting disables it.
    val_dataiter = None
    if param_valset_pickle_file_path != '' and param_val_batch_size != 0 and \
       param_num_val_loops != 0 and param_num_thread_val_dataiter != 0:
        val_data_provider = PickleProvider(param_valset_pickle_file_path)
        val_dataiter = DataIter(
            num_threads=param_num_thread_val_dataiter,
            data_provider=val_data_provider,
            batch_size=param_val_batch_size,
            **common_dataiter_kwargs)
    # ---------------------------------------------------------------------------------------------
    # construct metrics
    from metric_farm.metric_default import Metric

    train_metric = Metric(param_num_output_scales)
    val_metric = None
    if val_dataiter is not None:
        val_metric = Metric(param_num_output_scales)

    train_GOCD.start_train(
        param_dict=param_dict,
        mxnet_module=mxnet,
        context=[mxnet.gpu(i) for i in param_GPU_idx_list],
        train_dataiter=train_dataiter,
        train_metric=train_metric,
        train_metric_update_frequency=param_train_metric_update_frequency,
        num_train_loops=param_num_train_loops,
        val_dataiter=val_dataiter,
        val_metric=val_metric,
        num_val_loops=param_num_val_loops,
        validation_interval=param_validation_interval,
        optimizer_name=param_optimizer_name,
        optimizer_params=param_optimizer_params,
        net_symbol=net_symbol,
        net_initializer=net_initializer,
        net_data_names=data_names,
        net_label_names=label_names,
        pretrained_model_param_path=param_pretrained_model_param_path,
        display_interval=param_display_interval,
        save_prefix=param_save_prefix,
        model_save_interval=param_model_save_interval,
        start_index=param_start_index)
Example #2
0
def run():
    """Build the PyTorch network, optimizer, data iterators and metrics, then start training.

    All configuration is read from module-level ``param_*`` globals and nothing
    is returned. The network is moved to the first GPU in
    ``param_gpu_id_list``. The training loop itself is delegated to
    ``train_GOCD.start_train``.

    A validation iterator (and its metric) is only built when
    ``param_valset_pickle_file_path`` is not ``''`` and
    ``param_val_batch_size``, ``param_num_val_loops`` and
    ``param_num_thread_val_dataiter`` are all non-zero; otherwise ``None`` is
    passed for both.

    Raises:
        ValueError: if ``param_gpu_id_list`` is empty.
    """
    logging_GOCD.init_logging(log_file_path=param_log_file_path,
                              log_file_mode=param_log_mode)

    logging.info('Preparing before training.')
    # Presumably makes project packages such as net_farm importable
    # from the parent directory.
    sys.path.append('..')
    from net_farm.naivenet import naivenet20

    net = naivenet20()
    net_initializer = 'default'

    # Fail with a clear message instead of a bare IndexError below.
    if not param_gpu_id_list:
        raise ValueError('param_gpu_id_list must contain at least one GPU id.')

    # Make the first listed GPU the current CUDA device and move the network
    # onto it.
    torch.cuda.set_device(param_gpu_id_list[0])
    net = net.cuda(param_gpu_id_list[0])

    # construct the optimizer and its learning-rate scheduler
    param_optimizer = optim.SGD(net.parameters(),
                                lr=param_learning_rate,
                                momentum=param_momentum,
                                weight_decay=param_weight_decay)
    # The learning rate is multiplied by param_scheduler_factor each time the
    # scheduler's step count reaches a value in param_scheduler_step_list
    # (whether a step is an iteration or an epoch depends on train_GOCD).
    param_lr_scheduler = optim.lr_scheduler.MultiStepLR(
        param_optimizer,
        milestones=param_scheduler_step_list,
        gamma=param_scheduler_factor)

    # "hnm" presumably stands for hard negative mining — confirm in the loss
    # implementation.
    loss_criterion = cross_entropy_with_hnm_for_one_class_detection(
        param_hnm_ratio, param_num_output_scales)

    logging.info('Get net model successfully.')

    # -------------------------------------------------------------------------
    # init dataiter
    from data_provider_farm.pickle_provider import PickleProvider
    from data_iterator_farm.multithread_dataiter_for_cross_entropy_v2 import \
         Multithread_DataIter_for_CrossEntropy as DataIter

    # Keyword arguments shared by the training and validation iterators. Only
    # the data provider, thread count and batch size differ between the two,
    # so keeping the rest in one place prevents the two calls from drifting.
    # NOTE(review): the validation iterator receives the same augmentation
    # settings (flips, brightness, saturation, contrast, blur) as training —
    # confirm this is intended.
    common_dataiter_kwargs = dict(
        torch_module=torch,
        enable_horizon_flip=param_enable_horizon_flip,
        enable_vertical_flip=param_enable_vertical_flip,
        enable_random_brightness=param_enable_random_brightness,
        brightness_params=param_brightness_factors,
        enable_random_saturation=param_enable_random_saturation,
        saturation_params=param_saturation_factors,
        enable_random_contrast=param_enable_random_contrast,
        contrast_params=param_contrast_factors,
        enable_blur=param_enable_blur,
        blur_params=param_blur_factors,
        blur_kernel_size_list=param_blur_kernel_size_list,
        neg_image_ratio=param_neg_image_ratio,
        num_image_channels=param_num_image_channel,
        net_input_height=param_net_input_height,
        net_input_width=param_net_input_width,
        num_output_scales=param_num_output_scales,
        receptive_field_list=param_receptive_field_list,
        receptive_field_stride=param_receptive_field_stride,
        feature_map_size_list=param_feature_map_size_list,
        receptive_field_center_start=param_receptive_field_center_start,
        bbox_small_list=param_bbox_small_list,
        bbox_large_list=param_bbox_large_list,
        bbox_small_gray_list=param_bbox_small_gray_list,
        bbox_large_gray_list=param_bbox_large_gray_list,
        num_output_channels=param_num_output_channels,
        neg_image_resize_factor_interval=param_neg_image_resize_factor_interval,
    )

    train_data_provider = PickleProvider(param_trainset_pickle_file_path)
    train_dataiter = DataIter(
        num_threads=param_num_thread_train_dataiter,
        data_provider=train_data_provider,
        batch_size=param_train_batch_size,
        **common_dataiter_kwargs)

    # Validation is optional: any empty/zero validation setting disables it.
    val_dataiter = None
    if param_valset_pickle_file_path != '' and param_val_batch_size != 0 and \
       param_num_val_loops != 0 and param_num_thread_val_dataiter != 0:
        val_data_provider = PickleProvider(param_valset_pickle_file_path)
        val_dataiter = DataIter(
            num_threads=param_num_thread_val_dataiter,
            data_provider=val_data_provider,
            batch_size=param_val_batch_size,
            **common_dataiter_kwargs)
    # -------------------------------------------------------------------------
    # init metric
    from metric_farm.metric_default import Metric

    train_metric = Metric(param_num_output_scales)
    val_metric = None
    if val_dataiter is not None:
        val_metric = Metric(param_num_output_scales)

    train_GOCD.start_train(
        param_dict=param_dict,
        task_name=param_task_name,
        torch_module=torch,
        gpu_id_list=param_gpu_id_list,
        train_dataiter=train_dataiter,
        train_metric=train_metric,
        train_metric_update_frequency=param_train_metric_update_frequency,
        num_train_loops=param_num_train_loops,
        val_dataiter=val_dataiter,
        val_metric=val_metric,
        num_val_loops=param_num_val_loops,
        validation_interval=param_validation_interval,
        optimizer=param_optimizer,
        lr_scheduler=param_lr_scheduler,
        net=net,
        net_initializer=net_initializer,
        loss_criterion=loss_criterion,
        pretrained_model_param_path=param_pretrained_model_param_path,
        display_interval=param_display_interval,
        save_prefix=param_save_prefix,
        model_save_interval=param_model_save_interval,
        start_index=param_start_index)