def run():
    """Build the MXNet network, data iterators and metrics, then start training.

    All configuration is read from module-level ``param_*`` globals; the
    training loop itself is delegated to ``train_GOCD.start_train``.

    NOTE(review): the source this came from contains a second ``def run()``
    (a PyTorch variant) further down; if both live in the same module, the
    later definition shadows this one.
    """
    logging_GOCD.init_logging(log_file_path=param_log_file_path, log_file_mode=param_log_mode)

    logging.info('Preparing before training.')
    # Make the sibling *_farm packages importable. Guarded so repeated calls
    # do not keep growing sys.path (a duplicate '..' entry would never be
    # consulted during import resolution anyway).
    if '..' not in sys.path:
        sys.path.append('..')
    from symbol_farm import symbol_10_320_20L_5scales_v2 as net
    net_symbol, data_names, label_names = net.get_net_symbol()
    net_initializer = mxnet.initializer.Xavier()
    logging.info('Get net symbol successfully.')

    # -------------------------------------------------------------------------
    # Build the data iterators
    from data_provider_farm.pickle_provider import PickleProvider
    from data_iterator_farm.multithread_dataiter_for_cross_entropy_v2 import \
        Multithread_DataIter_for_CrossEntropy as DataIter

    # Settings shared by the training and validation iterators. Only the data
    # provider, thread count and batch size differ between the two, so they
    # are kept in one place to prevent the argument lists from drifting apart.
    # NOTE(review): validation therefore uses the same random-augmentation
    # flags as training; confirm that is intended.
    common_dataiter_kwargs = dict(
        mxnet_module=mxnet,
        enable_horizon_flip=param_enable_horizon_flip,
        enable_vertical_flip=param_enable_vertical_flip,
        enable_random_brightness=param_enable_random_brightness,
        brightness_params=param_brightness_factors,
        enable_random_saturation=param_enable_random_saturation,
        saturation_params=param_saturation_factors,
        enable_random_contrast=param_enable_random_contrast,
        contrast_params=param_contrast_factors,
        enable_blur=param_enable_blur,
        blur_params=param_blur_factors,
        blur_kernel_size_list=param_blur_kernel_size_list,
        neg_image_ratio=param_neg_image_ratio,
        num_image_channels=param_num_image_channel,
        net_input_height=param_net_input_height,
        net_input_width=param_net_input_width,
        num_output_scales=param_num_output_scales,
        receptive_field_list=param_receptive_field_list,
        receptive_field_stride=param_receptive_field_stride,
        feature_map_size_list=param_feature_map_size_list,
        receptive_field_center_start=param_receptive_field_center_start,
        bbox_small_list=param_bbox_small_list,
        bbox_large_list=param_bbox_large_list,
        bbox_small_gray_list=param_bbox_small_gray_list,
        bbox_large_gray_list=param_bbox_large_gray_list,
        num_output_channels=param_num_output_channels,
        neg_image_resize_factor_interval=param_neg_image_resize_factor_interval,
    )

    train_data_provider = PickleProvider(param_trainset_pickle_file_path)
    train_dataiter = DataIter(
        num_threads=param_num_thread_train_dataiter,
        data_provider=train_data_provider,
        batch_size=param_train_batch_size,
        **common_dataiter_kwargs
    )

    # Validation is optional: it is enabled only when a val pickle path is
    # given and none of the batch size, loop count or thread count is zero.
    val_dataiter = None
    if param_valset_pickle_file_path != '' and param_val_batch_size != 0 and \
            param_num_val_loops != 0 and param_num_thread_val_dataiter != 0:
        val_data_provider = PickleProvider(param_valset_pickle_file_path)
        val_dataiter = DataIter(
            num_threads=param_num_thread_val_dataiter,
            data_provider=val_data_provider,
            batch_size=param_val_batch_size,
            **common_dataiter_kwargs
        )

    # -------------------------------------------------------------------------
    # Build the metrics (a validation metric only exists with a val iterator)
    from metric_farm.metric_default import Metric
    train_metric = Metric(param_num_output_scales)
    val_metric = Metric(param_num_output_scales) if val_dataiter is not None else None

    train_GOCD.start_train(
        param_dict=param_dict,
        mxnet_module=mxnet,
        context=[mxnet.gpu(i) for i in param_GPU_idx_list],
        train_dataiter=train_dataiter,
        train_metric=train_metric,
        train_metric_update_frequency=param_train_metric_update_frequency,
        num_train_loops=param_num_train_loops,
        val_dataiter=val_dataiter,
        val_metric=val_metric,
        num_val_loops=param_num_val_loops,
        validation_interval=param_validation_interval,
        optimizer_name=param_optimizer_name,
        optimizer_params=param_optimizer_params,
        net_symbol=net_symbol,
        net_initializer=net_initializer,
        net_data_names=data_names,
        net_label_names=label_names,
        pretrained_model_param_path=param_pretrained_model_param_path,
        display_interval=param_display_interval,
        save_prefix=param_save_prefix,
        model_save_interval=param_model_save_interval,
        start_index=param_start_index,
    )
def run():
    """Build the PyTorch network, optimizer, data iterators and metrics, then train.

    All configuration is read from module-level ``param_*`` globals. The model
    is moved to the first GPU in ``param_gpu_id_list`` and the training loop
    is delegated to ``train_GOCD.start_train``.
    """
    logging_GOCD.init_logging(log_file_path=param_log_file_path, log_file_mode=param_log_mode)

    logging.info('Preparing before training.')
    # Make the sibling *_farm packages importable. Guarded so repeated calls
    # do not keep growing sys.path (a duplicate '..' entry would never be
    # consulted during import resolution anyway).
    if '..' not in sys.path:
        sys.path.append('..')
    from net_farm.naivenet import naivenet20
    net = naivenet20()
    # Passed through to start_train as-is; its interpretation lives there.
    net_initializer = 'default'

    # Pin the process to the first listed GPU and move the model onto it.
    torch.cuda.set_device(param_gpu_id_list[0])
    net = net.cuda(param_gpu_id_list[0])

    # Optimizer, then a step-wise LR schedule: the learning rate is multiplied
    # by param_scheduler_factor at each milestone in param_scheduler_step_list.
    optimizer = optim.SGD(net.parameters(),
                          lr=param_learning_rate,
                          momentum=param_momentum,
                          weight_decay=param_weight_decay)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=param_scheduler_step_list,
        gamma=param_scheduler_factor)

    # presumably "hnm" = hard negative mining; param_hnm_ratio is its ratio —
    # confirm against cross_entropy_with_hnm_for_one_class_detection.
    loss_criterion = cross_entropy_with_hnm_for_one_class_detection(
        param_hnm_ratio, param_num_output_scales)
    logging.info('Get net model successfully.')

    # -------------------------------------------------------------------------
    # Build the data iterators
    from data_provider_farm.pickle_provider import PickleProvider
    from data_iterator_farm.multithread_dataiter_for_cross_entropy_v2 import \
        Multithread_DataIter_for_CrossEntropy as DataIter

    # Settings shared by the training and validation iterators. Only the data
    # provider, thread count and batch size differ between the two, so they
    # are kept in one place to prevent the argument lists from drifting apart.
    # NOTE(review): validation therefore uses the same random-augmentation
    # flags as training; confirm that is intended.
    common_dataiter_kwargs = dict(
        torch_module=torch,
        enable_horizon_flip=param_enable_horizon_flip,
        enable_vertical_flip=param_enable_vertical_flip,
        enable_random_brightness=param_enable_random_brightness,
        brightness_params=param_brightness_factors,
        enable_random_saturation=param_enable_random_saturation,
        saturation_params=param_saturation_factors,
        enable_random_contrast=param_enable_random_contrast,
        contrast_params=param_contrast_factors,
        enable_blur=param_enable_blur,
        blur_params=param_blur_factors,
        blur_kernel_size_list=param_blur_kernel_size_list,
        neg_image_ratio=param_neg_image_ratio,
        num_image_channels=param_num_image_channel,
        net_input_height=param_net_input_height,
        net_input_width=param_net_input_width,
        num_output_scales=param_num_output_scales,
        receptive_field_list=param_receptive_field_list,
        receptive_field_stride=param_receptive_field_stride,
        feature_map_size_list=param_feature_map_size_list,
        receptive_field_center_start=param_receptive_field_center_start,
        bbox_small_list=param_bbox_small_list,
        bbox_large_list=param_bbox_large_list,
        bbox_small_gray_list=param_bbox_small_gray_list,
        bbox_large_gray_list=param_bbox_large_gray_list,
        num_output_channels=param_num_output_channels,
        neg_image_resize_factor_interval=param_neg_image_resize_factor_interval,
    )

    train_data_provider = PickleProvider(param_trainset_pickle_file_path)
    train_dataiter = DataIter(
        num_threads=param_num_thread_train_dataiter,
        data_provider=train_data_provider,
        batch_size=param_train_batch_size,
        **common_dataiter_kwargs
    )

    # Validation is optional: it is enabled only when a val pickle path is
    # given and none of the batch size, loop count or thread count is zero.
    val_dataiter = None
    if param_valset_pickle_file_path != '' and param_val_batch_size != 0 and \
            param_num_val_loops != 0 and param_num_thread_val_dataiter != 0:
        val_data_provider = PickleProvider(param_valset_pickle_file_path)
        val_dataiter = DataIter(
            num_threads=param_num_thread_val_dataiter,
            data_provider=val_data_provider,
            batch_size=param_val_batch_size,
            **common_dataiter_kwargs
        )

    # -------------------------------------------------------------------------
    # Build the metrics (a validation metric only exists with a val iterator)
    from metric_farm.metric_default import Metric
    train_metric = Metric(param_num_output_scales)
    val_metric = Metric(param_num_output_scales) if val_dataiter is not None else None

    train_GOCD.start_train(
        param_dict=param_dict,
        task_name=param_task_name,
        torch_module=torch,
        gpu_id_list=param_gpu_id_list,
        train_dataiter=train_dataiter,
        train_metric=train_metric,
        train_metric_update_frequency=param_train_metric_update_frequency,
        num_train_loops=param_num_train_loops,
        val_dataiter=val_dataiter,
        val_metric=val_metric,
        num_val_loops=param_num_val_loops,
        validation_interval=param_validation_interval,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        net=net,
        net_initializer=net_initializer,
        loss_criterion=loss_criterion,
        pretrained_model_param_path=param_pretrained_model_param_path,
        display_interval=param_display_interval,
        save_prefix=param_save_prefix,
        model_save_interval=param_model_save_interval,
        start_index=param_start_index,
    )