def get_model(args):
    '''Build the backbone for evaluation and load weights from a .ckpt file.'''
    net = get_backbone(args)
    if args.fp16:
        net.add_flags_recursive(fp16=True)

    if args.weight.endswith('.ckpt'):
        param_dict = load_checkpoint(args.weight)
        param_dict_new = {}
        for key, value in param_dict.items():
            if key.startswith('moments.'):
                # skip optimizer moment parameters
                continue
            elif key.startswith('network.'):
                # strip the 'network.' wrapper prefix so keys match the bare backbone
                param_dict_new[key[8:]] = value
            else:
                param_dict_new[key] = value
        load_param_into_net(net, param_dict_new)
        args.logger.info('INFO, ------------- load model success--------------')
    else:
        args.logger.info('ERROR, unsupported file:{}, please check weight in config.py'.format(args.weight))
        return 0

    net.set_train(False)
    return net
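# A minimal standalone sketch (helper name assumed, not part of the original code) of the
# checkpoint-key remapping that get_model() above and main() below both perform inline:
# optimizer 'moments.' entries are dropped and the 'network.' wrapper prefix is stripped
# so the remaining keys line up with the bare backbone's parameter names.
def remap_checkpoint_keys(param_dict):
    '''Return a copy of param_dict with optimizer state removed and wrapper prefixes stripped.'''
    param_dict_new = {}
    for key, value in param_dict.items():
        if key.startswith('moments.'):
            continue  # optimizer moment state, not needed for inference or export
        if key.startswith('network.'):
            param_dict_new[key[len('network.'):]] = value
        else:
            param_dict_new[key] = value
    return param_dict_new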
def main(args):
    '''Load a pretrained checkpoint and export the backbone as an AIR file.'''
    network = get_backbone(args)

    ckpt_path = args.pretrained
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                # skip optimizer moment parameters
                continue
            elif key.startswith('network.'):
                # strip the 'network.' wrapper prefix so keys match the bare backbone
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print('-----------------------load model success-----------------------')
    else:
        print('-----------------------load model failed-----------------------')

    network.add_flags_recursive(fp16=True)
    network.set_train(False)

    # dummy input matching the 112x112 RGB face crops expected by the backbone
    input_data = np.random.uniform(low=0, high=1.0,
                                   size=(args.batch_size, 3, 112, 112)).astype(np.float32)
    tensor_input_data = Tensor(input_data)

    # e.g. 'model.ckpt' with batch_size 8 becomes 'model_8b.air'
    file_path = ckpt_path.replace('.ckpt', '_' + str(args.batch_size) + 'b.air')
    export(network, tensor_input_data, file_name=file_path, file_format='AIR')
    print('-----------------------export model success, save file:{}-----------------------'.format(file_path))
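# Hypothetical invocation sketch for the export entry point above. The argument names are
# assumptions inferred only from the attributes accessed in main() (args.pretrained,
# args.batch_size); the real script may define additional options that get_backbone(args)
# requires, so treat this as an illustration rather than the repository's actual CLI.
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='export checkpoint to AIR')
    parser.add_argument('--pretrained', type=str, required=True, help='path to a .ckpt file')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size baked into the exported graph')
    main(parser.parse_args())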
def run_train():
    '''run train function.'''
    config.local_rank = get_rank_id()
    config.world_size = get_device_num()

    log_path = os.path.join(config.ckpt_path, 'logs')
    config.logger = get_logger(log_path, config.local_rank)

    support_train_stage = ['base', 'beta']
    if config.train_stage.lower() not in support_train_stage:
        config.logger.info('your train stage is not supported.')
        raise ValueError('train stage not supported.')

    if not os.path.exists(config.data_dir):
        config.logger.info('ERROR, data_dir does not exist, please set data_dir in config.py')
        raise ValueError('ERROR, data_dir does not exist, please set data_dir in config.py')

    parallel_mode = ParallelMode.HYBRID_PARALLEL if config.is_distributed else ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=config.world_size,
                                      gradients_mean=True)
    if config.is_distributed:
        init()

    if config.local_rank % 8 == 0:
        if not os.path.exists(config.ckpt_path):
            os.makedirs(config.ckpt_path)

    de_dataset, steps_per_epoch, num_classes = get_de_dataset(config)
    config.logger.info('de_dataset: %d', de_dataset.get_dataset_size())

    config.steps_per_epoch = steps_per_epoch
    config.num_classes = num_classes
    config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
    config.logger.info('config.num_classes: %d', config.num_classes)
    config.logger.info('config.world_size: %d', config.world_size)
    config.logger.info('config.local_rank: %d', config.local_rank)
    config.logger.info('config.lr: %f', config.lr)

    if config.nc_16 == 1:
        # pad the class count up to a multiple of 16 (or world_size * 16 for model parallel)
        if config.model_parallel == 0:
            if config.num_classes % 16 == 0:
                config.logger.info('data parallel already 16, nums: %d', config.num_classes)
            else:
                config.num_classes = (config.num_classes // 16 + 1) * 16
        else:
            if config.num_classes % (config.world_size * 16) == 0:
                config.logger.info('model parallel already 16, nums: %d', config.num_classes)
            else:
                config.num_classes = (config.num_classes // (config.world_size * 16) + 1) * config.world_size * 16

    config.logger.info('for D, loaded, class nums: %d', config.num_classes)
    config.logger.info('steps_per_epoch: %d', config.steps_per_epoch)
    config.logger.info('img_total_num: %d', config.steps_per_epoch * config.per_batch_size)

    config.logger.info('get_backbone----in----')
    _backbone = get_backbone(config)
    config.logger.info('get_backbone----out----')
    config.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(config)
    config.logger.info('get_metric_fc----out----')

    config.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    config.logger.info('DistributedHelper----out----')

    config.logger.info('network fp16----in----')
    if config.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    config.logger.info('network fp16----out----')

    criterion_1 = get_loss(config)
    if config.fp16 == 1 and config.model_parallel == 0:
        criterion_1.add_flags_recursive(fp32=True)

    network_1 = load_pretrain(config, network_1)
    train_net = BuildTrainNetwork(network_1, criterion_1, config)

    # warmup_step_list must be called after config.steps_per_epoch is set above
    config.lrs = warmup_step_list(config, gamma=0.1)
    lrs_gen = list_to_gen(config.lrs)
    opt = Momentum(params=train_net.trainable_params(),
                   learning_rate=lrs_gen,
                   momentum=config.momentum,
                   weight_decay=config.weight_decay)
    scale_manager = DynamicLossScaleManager(init_loss_scale=config.dynamic_init_loss_scale,
                                            scale_factor=2,
                                            scale_window=2000)
    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=scale_manager)

    save_checkpoint_steps = config.ckpt_steps
    config.logger.info('save_checkpoint_steps: %d', save_checkpoint_steps)
    if config.max_ckpts == -1:
        keep_checkpoint_max = int(config.steps_per_epoch * config.max_epoch / save_checkpoint_steps) + 5
    else:
        keep_checkpoint_max = config.max_ckpts
    config.logger.info('keep_checkpoint_max: %d', keep_checkpoint_max)

    ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                   keep_checkpoint_max=keep_checkpoint_max)
    config.logger.info('max_epoch_train: %d', config.max_epoch)
    ckpt_cb = ModelCheckpoint(config=ckpt_config,
                              directory=config.ckpt_path,
                              prefix='{}'.format(config.local_rank))
    config.epoch_cnt = 0
    progress_cb = ProgressMonitor(config)

    # each data-sink pass of log_interval steps counts as one "epoch" for model.train
    new_epoch_train = config.max_epoch * steps_per_epoch // config.log_interval
    model.train(new_epoch_train,
                de_dataset,
                callbacks=[progress_cb, ckpt_cb],
                sink_size=config.log_interval)
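# Sketch (helper name assumed, not part of the original code) of the nc_16 class-count
# padding applied in run_train() above: the number of classes is rounded up to the next
# multiple of 16 (data parallel) or of world_size * 16 (model parallel), presumably so the
# final fully connected layer width stays 16-aligned for the fp16 kernels.
def pad_num_classes(num_classes, multiple):
    '''Round num_classes up to the next multiple of `multiple`.'''
    if num_classes % multiple == 0:
        return num_classes
    return (num_classes // multiple + 1) * multiple

# e.g. pad_num_classes(10572, 16) -> 10576 and pad_num_classes(10572, 8 * 16) -> 10624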
        else:
            if args.num_classes % (args.world_size * 16) == 0:
                args.logger.info('model parallel already 16, nums: {}'.format(args.num_classes))
            else:
                args.num_classes = (args.num_classes // (args.world_size * 16) + 1) * args.world_size * 16

    args.logger.info('for D, loaded, class nums: {}'.format(args.num_classes))
    args.logger.info('steps_per_epoch:{}'.format(args.steps_per_epoch))
    args.logger.info('img_total_num:{}'.format(args.steps_per_epoch * args.per_batch_size))

    args.logger.info('get_backbone----in----')
    _backbone = get_backbone(args)
    args.logger.info('get_backbone----out----')
    args.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(args)
    args.logger.info('get_metric_fc----out----')

    args.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    args.logger.info('DistributedHelper----out----')

    args.logger.info('network fp16----in----')
    if args.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    args.logger.info('network fp16----out----')