def _get_optimizer(args_opt, network):
    """get bert optimizer, support Lamb, Momentum, AdamWeightDecay."""
    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]
        if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
            optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
        elif context.get_context("mode") == context.PYNATIVE_MODE and args_opt.device_target == 'GPU':
            optimizer = AdamWeightDecayOp(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
        else:
            optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == "Thor":
        from src.utils import get_bert_thor_lr, get_bert_thor_damping
        lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
        damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min, cfg.Thor.damping_power,
                                        cfg.Thor.damping_total_steps)
        split_indices = None
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
            else:
                split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
            else:
                split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
        optimizer = THOR(network, lr, damping, cfg.Thor.momentum,
                         cfg.Thor.weight_decay, cfg.Thor.loss_scale, cfg.batch_size,
                         decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                         split_indices=split_indices)
    else:
        raise ValueError("Don't support optimizer {}, only support [Lamb, Momentum, AdamWeightDecay, Thor]".
                         format(cfg.optimizer))
    return optimizer
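A minimal, self-contained sketch of the decay/no-decay parameter grouping used by the Lamb and AdamWeightDecay branches above; the nn.Dense stand-in network, the filter, and the hyperparameter values are illustrative assumptions, not the project's configuration.
import mindspore.nn as nn

def no_decay(p):
    # mirrors the decay filter shown in the Thor branch above: skip weight decay
    # for LayerNorm and bias parameters
    return 'layernorm' in p.name.lower() or 'bias' in p.name.lower()

net = nn.Dense(8, 2)  # stand-in for the BERT network (illustrative)
params = net.trainable_params()
decay_params = [p for p in params if not no_decay(p)]
other_params = [p for p in params if no_decay(p)]
group_params = [{'params': decay_params, 'weight_decay': 0.01},
                {'params': other_params, 'weight_decay': 0.0},
                {'order_params': params}]
optimizer = nn.AdamWeightDecay(group_params, learning_rate=1e-4, eps=1e-6)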
Example #2
def train_process_bert_thor(q, device_id, epoch_size, device_num):
    """Run BERT pre-training with the THOR optimizer on one device and put loss/step-time stats on q."""
    os.makedirs(str(device_id), exist_ok=True)
    os.chdir(str(device_id))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(max_call_depth=3000)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
    os.environ['RANK_ID'] = str(device_id)
    os.environ['RANK_SIZE'] = str(device_num)

    # initialize collective communication for distributed (data-parallel) training
    D.init()
    rank = device_id % device_num
    context.reset_auto_parallel_context()
    _set_bert_all_reduce_split()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                      device_num=device_num)

    data_set = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False, data_dir=DATASET_PATH,
                                   schema_dir=None)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # number of data-sink iterations, capped by both epoch_size and train_steps
    new_repeat_count = epoch_size * data_set.get_dataset_size() // data_sink_steps
    new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)

    lr = get_bert_thor_lr()
    damping = get_bert_thor_damping()
    split_indices = [38, 77]
    optimizer = THOR(net_with_loss, lr, damping, momentum, weight_decay, loss_scale, batch_size,
                     decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                     split_indices=split_indices)
    time_monitor_callback = TimeMonitor(data_sink_steps)
    loss_callback = LossCallback()
    callback = [time_monitor_callback, loss_callback]

    # optionally warm-start from a pre-trained checkpoint
    if load_checkpoint_path:
        param_dict = load_checkpoint(load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
    model = Model(net_with_grads)
    # convert to a THOR model; `frequency` controls how often second-order information is refreshed
    model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
                                                      frequency=frequency)
    model.train(new_repeat_count, data_set, callbacks=callback, dataset_sink_mode=True, sink_size=data_sink_steps)

    loss_list = loss_callback.loss_list
    per_step_mseconds = time_monitor_callback.per_step_mseconds_list
    q.put({'loss': loss_list, 'cost': per_step_mseconds})
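A hedged sketch of how a per-device worker such as train_process_bert_thor might be launched; the launch_bert_thor helper, its default arguments, and the process layout are illustrative assumptions rather than the original test harness.
from multiprocessing import Process, Queue

def launch_bert_thor(device_num=8, epoch_size=2):
    # spawn one training process per device and collect each worker's {'loss', 'cost'} dict
    q = Queue()
    workers = []
    for device_id in range(device_num):
        p = Process(target=train_process_bert_thor,
                    args=(q, device_id, epoch_size, device_num))
        p.start()
        workers.append(p)
    results = [q.get() for _ in range(device_num)]  # drain the queue before joining
    for p in workers:
        p.join()
    return results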
Example #3
        else:
            # fp32 training
            opt = Momentum(
                filter(lambda x: x.requires_grad, net.get_parameters()), lr,
                config.momentum, config.weight_decay)
            model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
    if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
        from src.lr_generator import get_thor_damping
        damping = get_thor_damping(0, config.damping_init,
                                   config.damping_decay, 70, step_size)
        split_indices = [26, 53]
        opt = THOR(net,
                   lr,
                   Tensor(damping),
                   config.momentum,
                   config.weight_decay,
                   config.loss_scale,
                   config.batch_size,
                   split_indices=split_indices)
        model = ConvertModelUtils().convert_to_thor_model(
            model=model,
            network=net,
            loss_fn=loss,
            optimizer=opt,
            loss_scale_manager=loss_scale,
            metrics={'acc'},
            amp_level="O2",
            keep_batchnorm_fp32=False,
            frequency=config.frequency)

def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
    """Run ResNet-50 ImageNet training with the THOR optimizer on one device and put accuracy/step-time stats on q."""
    os.makedirs(str(device_id), exist_ok=True)
    os.chdir(str(device_id))
    context.set_context(mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        save_graphs=False)
    context.set_context(device_id=device_id)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH_2
    os.environ['RANK_ID'] = str(device_id - 4)
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(
            device_num=device_num,
            parallel_mode=ParallelMode.DATA_PARALLEL,
            gradients_mean=True,
            all_reduce_fusion_config=[85, 160])
        init()

    # network
    net = resnet50_thor(thor_config.class_num)

    if not thor_config.label_smooth:
        thor_config.label_smooth_factor = 0.0

    # loss
    loss = CrossEntropySmooth(sparse=True,
                              reduction="mean",
                              smooth_factor=thor_config.label_smooth_factor,
                              num_classes=thor_config.class_num)

    # train dataset
    dataset = create_dataset_thor(dataset_path=dataset_path,
                                  do_train=True,
                                  repeat_num=1,
                                  batch_size=thor_config.batch_size)

    step_size = dataset.get_dataset_size()
    eval_interval = thor_config.eval_interval

    # evaluation dataset
    eval_dataset = create_dataset(dataset_path=eval_path,
                                  do_train=False,
                                  repeat_num=1,
                                  batch_size=thor_config.eval_batch_size)

    # loss scale
    loss_scale = FixedLossScaleManager(thor_config.loss_scale,
                                       drop_overflow_update=False)

    # learning rate
    lr = get_thor_lr(0, 0.05803, 4.04839, 53, 5004, decay_epochs=39)
    damping = get_thor_damping(0, 0.02714, 0.50036, 70, 5004)
    # optimizer
    split_indices = [26, 53]
    opt = THOR(net,
               Tensor(lr),
               Tensor(damping),
               thor_config.momentum,
               thor_config.weight_decay,
               thor_config.loss_scale,
               thor_config.batch_size,
               split_indices=split_indices)

    # evaluation network
    dist_eval_network = ClassifyCorrectCell(net)
    # model
    model = THOR_Model(net,
                       loss_fn=loss,
                       optimizer=opt,
                       loss_scale_manager=loss_scale,
                       amp_level="O2",
                       keep_batchnorm_fp32=False,
                       metrics={
                           'acc':
                           DistAccuracy(batch_size=thor_config.eval_batch_size,
                                        device_num=device_num)
                       },
                       eval_network=dist_eval_network,
                       frequency=thor_config.frequency)

    # model init
    print("init_start", device_id)
    model.init(dataset, eval_dataset)
    print("init_stop", device_id)

    # callbacks
    loss_cb = LossGet(1, step_size)

    # train and eval
    acc = 0.0
    time_cost = 0.0
    print("run_start", device_id)
    for epoch_idx in range(epoch_size // eval_interval):
        model.train(eval_interval, dataset, callbacks=loss_cb)
        eval_start = time.time()
        output = model.eval(eval_dataset)
        eval_cost = (time.time() - eval_start) * 1000
        acc = float(output["acc"])
        time_cost = loss_cb.get_per_step_time()
        loss = loss_cb.get_loss()
        print(
            "epoch {} resnet result:\n "
            "device {}, training loss {}, acc {}, "
            "training per step cost {:.2f} ms, eval cost {:.2f} ms, total cost {:.2f} ms"
            .format(epoch_idx, device_id, loss, acc, time_cost, eval_cost,
                    time_cost * step_size + eval_cost))
    q.put({'acc': acc, 'cost': time_cost})