Example #1
from mindspore.train.serialization import load_checkpoint, load_param_into_net


def get_model(args):
    '''Build the backbone network and load evaluation weights from a checkpoint.'''
    # get_backbone is a project-local helper that builds the network from args
    net = get_backbone(args)
    if args.fp16:
        net.add_flags_recursive(fp16=True)

    if args.weight.endswith('.ckpt'):
        param_dict = load_checkpoint(args.weight)
        param_dict_new = {}
        for key, value in param_dict.items():
            if key.startswith('moments.'):
                # skip optimizer moment variables; they are not model weights
                continue
            elif key.startswith('network.'):
                # strip the 'network.' training-wrapper prefix (8 characters)
                param_dict_new[key[8:]] = value
            else:
                param_dict_new[key] = value
        load_param_into_net(net, param_dict_new)
        args.logger.info(
            'INFO, ------------- load model success--------------')
    else:
        args.logger.info(
            'ERROR, unsupported file: {}, please check weight in config.py'.
            format(args.weight))
        return 0
    net.set_train(False)
    return net
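The portable piece of this loader is the key remapping: optimizer state under 'moments.' is dropped, and the 'network.' training-wrapper prefix is stripped so the checkpoint keys match a bare backbone. A minimal, framework-free sketch of just that remapping step (the helper name strip_train_prefixes is hypothetical):

def strip_train_prefixes(param_dict):
    '''Drop optimizer state and unwrap training-cell key prefixes (sketch).'''
    new_dict = {}
    for key, value in param_dict.items():
        if key.startswith('moments.'):
            continue  # optimizer moments, not model weights
        if key.startswith('network.'):
            key = key[len('network.'):]  # unwrap the training-cell prefix
        new_dict[key] = value
    return new_dict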
Example #2
import os

import numpy as np

from mindspore import Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net


def main(args):
    network = get_backbone(args)

    ckpt_path = args.pretrained
    if os.path.isfile(ckpt_path):
        param_dict = load_checkpoint(ckpt_path)
        param_dict_new = {}
        for key, values in param_dict.items():
            if key.startswith('moments.'):
                # skip optimizer moment variables
                continue
            elif key.startswith('network.'):
                # strip the 'network.' training-wrapper prefix (8 characters)
                param_dict_new[key[8:]] = values
            else:
                param_dict_new[key] = values
        load_param_into_net(network, param_dict_new)
        print(
            '-----------------------load model success-----------------------')
    else:
        print(
            '-----------------------load model failed-----------------------')

    network.add_flags_recursive(fp16=True)
    network.set_train(False)

    # random input matching the model's expected NCHW shape (batch, 3, 112, 112)
    input_data = np.random.uniform(low=0,
                                   high=1.0,
                                   size=(args.batch_size, 3, 112,
                                         112)).astype(np.float32)
    tensor_input_data = Tensor(input_data)

    file_path = ckpt_path.replace('.ckpt',
                                  '_' + str(args.batch_size) + 'b.air')
    export(network, tensor_input_data, file_name=file_path, file_format='AIR')
    print(
        '-----------------------export model success, save file:{}-----------------------'
        .format(file_path))
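The output filename is derived purely by string substitution, so it can be checked without MindSpore. A quick illustration with a hypothetical checkpoint path:

ckpt_path = '/models/face.ckpt'  # hypothetical path
batch_size = 16
air_path = ckpt_path.replace('.ckpt', '_' + str(batch_size) + 'b.air')
print(air_path)  # /models/face_16b.air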
Example #3
def run_train():
    '''Run the training pipeline: dataset, network, loss, optimizer, then model.train.'''
    config.local_rank = get_rank_id()
    config.world_size = get_device_num()
    log_path = os.path.join(config.ckpt_path, 'logs')
    config.logger = get_logger(log_path, config.local_rank)

    support_train_stage = ['base', 'beta']
    if config.train_stage.lower() not in support_train_stage:
        config.logger.info('your train stage is not supported.')
        raise ValueError('train stage not supported.')

    if not os.path.exists(config.data_dir):
        config.logger.info(
            'ERROR, data_dir does not exist, please set data_dir in config.py')
        raise ValueError(
            'ERROR, data_dir does not exist, please set data_dir in config.py')

    parallel_mode = ParallelMode.HYBRID_PARALLEL if config.is_distributed else ParallelMode.STAND_ALONE
    context.set_auto_parallel_context(parallel_mode=parallel_mode,
                                      device_num=config.world_size,
                                      gradients_mean=True)
    if config.is_distributed:
        init()

    # only the first rank on each 8-device node creates the checkpoint directory
    if config.local_rank % 8 == 0:
        if not os.path.exists(config.ckpt_path):
            os.makedirs(config.ckpt_path)

    de_dataset, steps_per_epoch, num_classes = get_de_dataset(config)
    config.logger.info('de_dataset: %d', de_dataset.get_dataset_size())

    config.steps_per_epoch = steps_per_epoch
    config.num_classes = num_classes
    config.lr_epochs = list(map(int, config.lr_epochs.split(',')))
    config.logger.info('config.num_classes: %d', config.num_classes)
    config.logger.info('config.world_size: %d', config.world_size)
    config.logger.info('config.local_rank: %d', config.local_rank)
    config.logger.info('config.lr: %f', config.lr)

    # pad num_classes up to a multiple of 16 (data parallel) or of
    # world_size * 16 (model parallel)
    if config.nc_16 == 1:
        if config.model_parallel == 0:
            if config.num_classes % 16 == 0:
                config.logger.info('data parallel: already a multiple of 16, nums: %d',
                                   config.num_classes)
            else:
                config.num_classes = (config.num_classes // 16 + 1) * 16
        else:
            if config.num_classes % (config.world_size * 16) == 0:
                config.logger.info('model parallel: already a multiple of 16, nums: %d',
                                   config.num_classes)
            else:
                config.num_classes = (config.num_classes //
                                      (config.world_size * 16) +
                                      1) * config.world_size * 16

    config.logger.info('for D, loaded, class nums: %d', config.num_classes)
    config.logger.info('steps_per_epoch: %d', config.steps_per_epoch)
    config.logger.info('img_total_num: %d',
                       config.steps_per_epoch * config.per_batch_size)

    config.logger.info('get_backbone----in----')
    _backbone = get_backbone(config)
    config.logger.info('get_backbone----out----')
    config.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(config)
    config.logger.info('get_metric_fc----out----')
    config.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    config.logger.info('DistributedHelper----out----')
    config.logger.info('network fp16----in----')
    if config.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    config.logger.info('network fp16----out----')

    criterion_1 = get_loss(config)
    if config.fp16 == 1 and config.model_parallel == 0:
        criterion_1.add_flags_recursive(fp32=True)

    network_1 = load_pretrain(config, network_1)
    train_net = BuildTrainNetwork(network_1, criterion_1, config)

    # warmup_step_list must be called after config.steps_per_epoch has been set
    config.lrs = warmup_step_list(config, gamma=0.1)
    lrs_gen = list_to_gen(config.lrs)
    opt = Momentum(params=train_net.trainable_params(),
                   learning_rate=lrs_gen,
                   momentum=config.momentum,
                   weight_decay=config.weight_decay)
    scale_manager = DynamicLossScaleManager(
        init_loss_scale=config.dynamic_init_loss_scale,
        scale_factor=2,
        scale_window=2000)
    model = Model(train_net,
                  optimizer=opt,
                  metrics=None,
                  loss_scale_manager=scale_manager)

    save_checkpoint_steps = config.ckpt_steps
    config.logger.info('save_checkpoint_steps: %d', save_checkpoint_steps)
    if config.max_ckpts == -1:
        keep_checkpoint_max = int(config.steps_per_epoch * config.max_epoch /
                                  save_checkpoint_steps) + 5
    else:
        keep_checkpoint_max = config.max_ckpts
    config.logger.info('keep_checkpoint_max: %d', keep_checkpoint_max)

    ckpt_config = CheckpointConfig(save_checkpoint_steps=save_checkpoint_steps,
                                   keep_checkpoint_max=keep_checkpoint_max)
    config.logger.info('max_epoch_train: %d', config.max_epoch)
    ckpt_cb = ModelCheckpoint(config=ckpt_config,
                              directory=config.ckpt_path,
                              prefix='{}'.format(config.local_rank))
    config.epoch_cnt = 0
    progress_cb = ProgressMonitor(config)
    # with data sinking, model.train's first argument counts sink_size-step
    # chunks, so convert total steps into chunks of log_interval steps
    new_epoch_train = config.max_epoch * steps_per_epoch // config.log_interval
    model.train(new_epoch_train,
                de_dataset,
                callbacks=[progress_cb, ckpt_cb],
                sink_size=config.log_interval)
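Note that the final model.train call uses data sinking, so its first argument counts chunks of sink_size steps rather than dataset epochs. A worked example of that arithmetic, with illustrative numbers:

max_epoch = 20
steps_per_epoch = 1000
log_interval = 100                              # also passed as sink_size
total_steps = max_epoch * steps_per_epoch       # 20000 optimizer steps
new_epoch_train = total_steps // log_interval   # 200 sink chunks
# model.train(200, ..., sink_size=100) then executes 200 * 100 = 20000 steps.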
Example #4
File: train.py  Project: yrpang/mindspore
        else:
            if args.num_classes % (args.world_size * 16) == 0:
                args.logger.info('model parallel: already a multiple of 16, nums: {}'.format(
                    args.num_classes))
            else:
                args.num_classes = (args.num_classes //
                                    (args.world_size * 16) +
                                    1) * args.world_size * 16

    args.logger.info('for D, loaded, class nums: {}'.format(args.num_classes))
    args.logger.info('steps_per_epoch:{}'.format(args.steps_per_epoch))
    args.logger.info('img_total_num:{}'.format(args.steps_per_epoch *
                                               args.per_batch_size))

    args.logger.info('get_backbone----in----')
    _backbone = get_backbone(args)
    args.logger.info('get_backbone----out----')

    args.logger.info('get_metric_fc----in----')
    margin_fc_1 = get_metric_fc(args)
    args.logger.info('get_metric_fc----out----')

    args.logger.info('DistributedHelper----in----')
    network_1 = DistributedHelper(_backbone, margin_fc_1)
    args.logger.info('DistributedHelper----out----')

    args.logger.info('network fp16----in----')
    if args.fp16 == 1:
        network_1.add_flags_recursive(fp16=True)
    args.logger.info('network fp16----out----')
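The rounding branch at the top of this fragment pads args.num_classes up to the next multiple of world_size * 16. A worked example of the padding arithmetic, with illustrative numbers:

num_classes = 10572   # illustrative class count
world_size = 8
multiple = world_size * 16                            # 128
if num_classes % multiple != 0:
    num_classes = (num_classes // multiple + 1) * multiple
print(num_classes)  # 10624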