Example #1
def train_on_ascend():
    config = config_ascend_quant
    print("training args: {}".format(args_opt))
    print("training configure: {}".format(config))
    print("parallel args: rank_id {}, device_id {}, rank_size {}".format(
        rank_id, device_id, rank_size))
    epoch_size = config.epoch_size

    # distributed training init
    if run_distribute:
        context.set_auto_parallel_context(
            device_num=rank_size,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            parameter_broadcast=True,
            mirror_mean=True)
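        # mirror_mean=True averages gradients across devices (this flag was
        # renamed gradients_mean in later MindSpore releases)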
        init()

    # define network
    network = mobilenetV2(num_classes=config.num_classes)
    # define loss
    if config.label_smooth > 0:
        loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth,
                                           num_classes=config.num_classes)
    else:
        loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False,
                                                sparse=True,
                                                reduction='mean')
    # define dataset
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             do_train=True,
                             config=config,
                             device_target=args_opt.device_target,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()
    # load pre-trained checkpoint
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_param_into_net(network, param_dict)
    # convert fusion network to quantization aware network
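    # per_channel and symmetric are two-element lists: the first entry
    # configures the weight fake-quantizers, the second the activation
    # (data-flow) fake-quantizers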
    network = quant.convert_quant_network(network,
                                          bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False])

    # get learning rate
    lr = Tensor(
        get_lr(global_step=config.start_epoch * step_size,
               lr_init=0,
               lr_end=0,
               lr_max=config.lr,
               warmup_epochs=config.warmup_epochs,
               total_epochs=epoch_size + config.start_epoch,
               steps_per_epoch=step_size))

    # define optimizer
    opt = nn.Momentum(
        filter(lambda x: x.requires_grad, network.get_parameters()), lr,
        config.momentum, config.weight_decay)
    # define model
    model = Model(network, loss_fn=loss, optimizer=opt)

    print("============== Starting Training ==============")
    callback = None
    if rank_id == 0:
        callback = [Monitor(lr_init=lr.asnumpy())]
        if config.save_checkpoint:
            config_ck = CheckpointConfig(
                save_checkpoint_steps=config.save_checkpoint_epochs *
                step_size,
                keep_checkpoint_max=config.keep_checkpoint_max)
            ckpt_cb = ModelCheckpoint(prefix="mobilenetV2",
                                      directory=config.save_checkpoint_path,
                                      config=config_ck)
            callback += [ckpt_cb]
    model.train(epoch_size, dataset, callbacks=callback)
    print("============== End Training ==============")
Example #2
    if args_opt.run_distribute:
        init()
        context.set_auto_parallel_context(
            device_num=args_opt.device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True)
        auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
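        # fuse gradient all-reduce into chunks split at parameter indices
        # 107 and 160 so communication overlaps with back-propagation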

    # define network
    net = resnet50_quant(class_num=config.class_num)
    net.set_train(True)

    # weight init and load checkpoint file
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_param_into_net(net, param_dict)
        epoch_size = config.epoch_size - config.pretrained_epoch_size
    else:
        for _, cell in net.cells_and_names():
            if isinstance(cell, nn.Conv2d):
                cell.weight.default_input = weight_init.initializer(
                    weight_init.XavierUniform(), cell.weight.shape,
                    cell.weight.dtype)
            if isinstance(cell, nn.Dense):
                cell.weight.default_input = weight_init.initializer(
                    weight_init.TruncatedNormal(), cell.weight.shape,
                    cell.weight.dtype)
    if not config.use_label_smooth:
        config.label_smooth_factor = 0.0
    loss = CrossEntropy(smooth_factor=config.label_smooth_factor,
                        num_classes=config.class_num)
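
Example #2 breaks off at the loss definition. What follows is an assumed continuation in the style of Examples #1 and #3, not the source's code; lr and dataset would be built the same way as in those examples.

    # assumed continuation, mirroring the other two examples
    opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                      lr, config.momentum, config.weight_decay,
                      config.loss_scale)
    model = Model(net, loss_fn=loss, optimizer=opt,
                  loss_scale_manager=FixedLossScaleManager(
                      config.loss_scale, drop_overflow_update=False))
    model.train(epoch_size, dataset, callbacks=[Monitor(lr_init=lr.asnumpy())])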
Example #3
def train_on_gpu():
    config = config_gpu_quant
    print("training args: {}".format(args_opt))
    print("training configure: {}".format(config))

    # define network
    network = mobilenetV2(num_classes=config.num_classes)
    # define loss
    if config.label_smooth > 0:
        loss = CrossEntropyWithLabelSmooth(smooth_factor=config.label_smooth,
                                           num_classes=config.num_classes)
    else:
        loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False,
                                                sparse=True,
                                                reduction='mean')
    # define dataset
    epoch_size = config.epoch_size
    dataset = create_dataset(dataset_path=args_opt.dataset_path,
                             do_train=True,
                             config=config,
                             device_target=args_opt.device_target,
                             repeat_num=1,
                             batch_size=config.batch_size)
    step_size = dataset.get_dataset_size()
    # resume from a pre-trained checkpoint
    if args_opt.pre_trained:
        param_dict = load_checkpoint(args_opt.pre_trained)
        load_param_into_net(network, param_dict)

    # convert fusion network to quantization aware network
    network = quant.convert_quant_network(network,
                                          bn_fold=True,
                                          per_channel=[True, False],
                                          symmetric=[True, False],
                                          freeze_bn=1000000,
                                          quant_delay=step_size * 2)
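    # quant_delay=step_size * 2 keeps the fake-quantization ops pass-through
    # for the first two epochs; freeze_bn freezes batch-norm statistics
    # after 1000000 global steps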

    # define loss scale manager
    loss_scale = FixedLossScaleManager(config.loss_scale,
                                       drop_overflow_update=False)

    # get learning rate
    lr = Tensor(
        get_lr(global_step=config.start_epoch * step_size,
               lr_init=0,
               lr_end=0,
               lr_max=config.lr,
               warmup_epochs=config.warmup_epochs,
               total_epochs=epoch_size + config.start_epoch,
               steps_per_epoch=step_size))

    # define optimizer
    opt = nn.Momentum(
        filter(lambda x: x.requires_grad, network.get_parameters()), lr,
        config.momentum, config.weight_decay, config.loss_scale)
    # define model
    model = Model(network,
                  loss_fn=loss,
                  optimizer=opt,
                  loss_scale_manager=loss_scale)

    print("============== Starting Training ==============")
    callback = [Monitor(lr_init=lr.asnumpy())]
    ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(
        get_rank()) + "/"
    if config.save_checkpoint:
        config_ck = CheckpointConfig(
            save_checkpoint_steps=config.save_checkpoint_epochs * step_size,
            keep_checkpoint_max=config.keep_checkpoint_max)
        ckpt_cb = ModelCheckpoint(prefix="mobilenetV2",
                                  directory=ckpt_save_dir,
                                  config=config_ck)
        callback += [ckpt_cb]
    model.train(epoch_size, dataset, callbacks=callback)
    print("============== End Training ==============")