Example #1
File: train.py  Project: yrpang/mindspore
def run_transformer_train():
    """
    Transformer training.
    """
    parser = argparse_init()
    args, _ = parser.parse_known_args()
    if args.device_target == "Ascend":
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target,
                            device_id=args.device_id)
    else:
        context.set_context(mode=context.GRAPH_MODE,
                            device_target=args.device_target)
    context.set_context(reserve_class_name_in_scope=False,
                        enable_auto_mixed_precision=False)

    # Distributed setup: HCCL on Ascend, NCCL on GPU; gradients are averaged across devices.
    if args.distribute == "true":
        if args.device_target == "Ascend":
            device_num = args.device_num
            D.init('hccl')
        else:
            D.init('nccl')
            device_num = D.get_group_size()
            rank = get_rank()
            args.device_id = rank
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            device_num=device_num)
        rank_id = args.device_id % device_num
        save_ckpt_path = os.path.join(args.save_checkpoint_path,
                                      'ckpt_' + str(get_rank()) + '/')
    else:
        device_num = 1
        rank_id = 0
        save_ckpt_path = os.path.join(args.save_checkpoint_path, 'ckpt_0/')
    dataset = create_transformer_dataset(
        epoch_count=1,
        rank_size=device_num,
        rank_id=rank_id,
        do_shuffle=args.do_shuffle,
        dataset_path=args.data_path,
        bucket_boundaries=args.bucket_boundaries,
        device_target=args.device_target)
    if args.device_target == "Ascend":
        netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)
    else:
        netwithloss = TransformerNetworkWithLoss(transformer_net_cfg_gpu, True)

    if args.checkpoint_path:
        parameter_dict = load_checkpoint(args.checkpoint_path)
        load_param_into_net(netwithloss, parameter_dict)

    # Dynamic learning-rate schedule: constant * rsqrt_hidden * linear_warmup * rsqrt_decay.
    hidden_size = transformer_net_cfg.hidden_size if args.device_target == "Ascend" \
        else transformer_net_cfg_gpu.hidden_size
    learning_rate = cfg.lr_schedule.learning_rate if args.device_target == "Ascend" \
        else 1.0
    lr = Tensor(
        create_dynamic_lr(
            schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
            training_steps=dataset.get_dataset_size() * args.epoch_size,
            learning_rate=learning_rate,
            warmup_steps=cfg.lr_schedule.warmup_steps,
            hidden_size=hidden_size,
            start_decay_step=cfg.lr_schedule.start_decay_step,
            min_lr=cfg.lr_schedule.min_lr), mstype.float32)

    if args.device_target == "GPU" and cfg.transformer_network == "large":
        optimizer = Adam(netwithloss.trainable_params(),
                         lr,
                         beta2=cfg.optimizer_adam_beta2)
    else:
        optimizer = Adam(netwithloss.trainable_params(), lr)

    # Callbacks: per-step timing, loss logging, and optional checkpointing (rank 0 only when distributed).
    callbacks = [
        TimeMonitor(dataset.get_dataset_size()),
        LossCallBack(rank_id=rank_id)
    ]
    if args.enable_save_ckpt == "true":
        if device_num == 1 or (device_num > 1 and rank_id == 0):
            if args.device_target == "Ascend":
                ckpt_config = CheckpointConfig(
                    save_checkpoint_steps=args.save_checkpoint_steps,
                    keep_checkpoint_max=args.save_checkpoint_num)
            else:
                ckpt_config = CheckpointConfig(
                    save_checkpoint_steps=dataset.get_dataset_size(),
                    keep_checkpoint_max=args.save_checkpoint_num)
            ckpoint_cb = ModelCheckpoint(prefix='transformer',
                                         directory=save_ckpt_path,
                                         config=ckpt_config)
            callbacks.append(ckpoint_cb)

    # Wrap the loss network in a one-step training cell, with dynamic loss scaling if enabled.
    if args.enable_lossscale == "true":
        scale_manager = DynamicLossScaleManager(
            init_loss_scale=cfg.init_loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        update_cell = scale_manager.get_update_cell()
        netwithgrads = TransformerTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = TransformerTrainOneStepCell(netwithloss,
                                                   optimizer=optimizer)

    netwithgrads.set_train(True)
    model = Model(netwithgrads)

    model.train(args.epoch_size,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=False)
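
For context, a minimal entry-point sketch showing how this example is typically launched. The main guard simply calls the function above; the command lines in the comments are assumptions inferred from the flags the function reads (device_target, distribute, data_path, ...), not the project's documented usage.

if __name__ == "__main__":
    # argparse_init() inside run_transformer_train parses flags such as
    # --device_target, --distribute, --epoch_size and --data_path.
    run_transformer_train()

# Hypothetical launch commands (flag names taken from the attributes read above):
#   single device : python train.py --device_target=GPU --distribute=false --data_path=/path/to/data
#   multi-GPU     : mpirun -n 8 python train.py --device_target=GPU --distribute=true --data_path=/path/to/data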
Example #2
def run_transformer_train():
    """
    Transformer training.
    """
    parser = argparse_init()
    args, _ = parser.parse_known_args()
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=args.device_id)
    context.set_context(reserve_class_name_in_scope=False,
                        enable_auto_mixed_precision=False)

    # Ascend data-parallel setup; mirror_mean averages gradients across devices.
    if args.distribute == "true":
        device_num = args.device_num
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.DATA_PARALLEL,
            mirror_mean=True,
            parameter_broadcast=True,
            device_num=device_num)
        D.init()
        rank_id = args.device_id % device_num
    else:
        device_num = 1
        rank_id = 0
    dataset, repeat_count = create_transformer_dataset(
        epoch_count=args.epoch_size,
        rank_size=device_num,
        rank_id=rank_id,
        do_shuffle=args.do_shuffle,
        enable_data_sink=args.enable_data_sink,
        dataset_path=args.data_path)

    netwithloss = TransformerNetworkWithLoss(transformer_net_cfg, True)

    if args.checkpoint_path:
        parameter_dict = load_checkpoint(args.checkpoint_path)
        load_param_into_net(netwithloss, parameter_dict)

    lr = Tensor(
        create_dynamic_lr(
            schedule="constant*rsqrt_hidden*linear_warmup*rsqrt_decay",
            training_steps=dataset.get_dataset_size() * args.epoch_size,
            learning_rate=cfg.lr_schedule.learning_rate,
            warmup_steps=cfg.lr_schedule.warmup_steps,
            hidden_size=transformer_net_cfg.hidden_size,
            start_decay_step=cfg.lr_schedule.start_decay_step,
            min_lr=cfg.lr_schedule.min_lr), mstype.float32)
    optimizer = Adam(netwithloss.trainable_params(), lr)

    callbacks = [TimeMonitor(dataset.get_dataset_size()), LossCallBack()]
    if args.enable_save_ckpt == "true":
        ckpt_config = CheckpointConfig(
            save_checkpoint_steps=args.save_checkpoint_steps,
            keep_checkpoint_max=args.save_checkpoint_num)
        ckpoint_cb = ModelCheckpoint(prefix='transformer',
                                     directory=args.save_checkpoint_path,
                                     config=ckpt_config)
        callbacks.append(ckpoint_cb)

    if args.enable_lossscale == "true":
        scale_manager = DynamicLossScaleManager(
            init_loss_scale=cfg.init_loss_scale_value,
            scale_factor=cfg.scale_factor,
            scale_window=cfg.scale_window)
        update_cell = scale_manager.get_update_cell()
        netwithgrads = TransformerTrainOneStepWithLossScaleCell(
            netwithloss, optimizer=optimizer, scale_update_cell=update_cell)
    else:
        netwithgrads = TransformerTrainOneStepCell(netwithloss,
                                                   optimizer=optimizer)

    netwithgrads.set_train(True)
    model = Model(netwithgrads)
    model.train(repeat_count,
                dataset,
                callbacks=callbacks,
                dataset_sink_mode=(args.enable_data_sink == "true"))
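
Both examples derive their learning rate from the schedule string "constant*rsqrt_hidden*linear_warmup*rsqrt_decay". Below is a minimal standalone sketch of what a per-step schedule of that shape computes; it only illustrates the factor names and is not the project's create_dynamic_lr implementation.

def sketch_dynamic_lr(training_steps, learning_rate, warmup_steps,
                      hidden_size, start_decay_step, min_lr):
    """Illustrative per-step LR: constant * rsqrt_hidden * linear_warmup * rsqrt_decay."""
    lrs = []
    for step in range(1, training_steps + 1):
        lr = learning_rate                                         # "constant"
        lr *= hidden_size ** -0.5                                  # "rsqrt_hidden"
        lr *= min(1.0, step / warmup_steps)                        # "linear_warmup"
        lr *= max(step, start_decay_step, warmup_steps) ** -0.5    # "rsqrt_decay", held until decay starts
        lrs.append(max(lr, min_lr))                                # floor at min_lr
    return lrs

In the examples above, the corresponding inputs come from cfg.lr_schedule (learning_rate, warmup_steps, start_decay_step, min_lr) and the network config's hidden_size, with training_steps = dataset.get_dataset_size() * args.epoch_size.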