def _get_optimizer(args_opt, network):
    """Get the BERT optimizer; supports Lamb, Momentum, AdamWeightDecay, and Thor."""
    if cfg.optimizer == 'Lamb':
        lr_schedule = BertLearningRate(learning_rate=cfg.Lamb.learning_rate,
                                       end_learning_rate=cfg.Lamb.end_learning_rate,
                                       warmup_steps=cfg.Lamb.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.Lamb.power)
        params = network.trainable_params()
        # Apply weight decay only to the parameters selected by the configured decay filter.
        decay_params = list(filter(cfg.Lamb.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.Lamb.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.Lamb.weight_decay},
                        {'params': other_params},
                        {'order_params': params}]
        optimizer = Lamb(group_params, learning_rate=lr_schedule, eps=cfg.Lamb.eps)
    elif cfg.optimizer == 'Momentum':
        optimizer = Momentum(network.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
                             momentum=cfg.Momentum.momentum)
    elif cfg.optimizer == 'AdamWeightDecay':
        lr_schedule = BertLearningRate(learning_rate=cfg.AdamWeightDecay.learning_rate,
                                       end_learning_rate=cfg.AdamWeightDecay.end_learning_rate,
                                       warmup_steps=cfg.AdamWeightDecay.warmup_steps,
                                       decay_steps=args_opt.train_steps,
                                       power=cfg.AdamWeightDecay.power)
        params = network.trainable_params()
        decay_params = list(filter(cfg.AdamWeightDecay.decay_filter, params))
        other_params = list(filter(lambda x: not cfg.AdamWeightDecay.decay_filter(x), params))
        group_params = [{'params': decay_params, 'weight_decay': cfg.AdamWeightDecay.weight_decay},
                        {'params': other_params, 'weight_decay': 0.0},
                        {'order_params': params}]
        # Pick the AdamWeightDecay variant that matches the device target and execution mode.
        if args_opt.enable_lossscale == "true" and args_opt.device_target == 'GPU':
            optimizer = AdamWeightDecayForBert(group_params, learning_rate=lr_schedule,
                                               eps=cfg.AdamWeightDecay.eps)
        elif context.get_context("mode") == context.PYNATIVE_MODE and args_opt.device_target == 'GPU':
            optimizer = AdamWeightDecayOp(group_params, learning_rate=lr_schedule,
                                          eps=cfg.AdamWeightDecay.eps)
        else:
            optimizer = AdamWeightDecay(group_params, learning_rate=lr_schedule,
                                        eps=cfg.AdamWeightDecay.eps)
    elif cfg.optimizer == "Thor":
        from src.utils import get_bert_thor_lr, get_bert_thor_damping
        lr = get_bert_thor_lr(cfg.Thor.lr_max, cfg.Thor.lr_min, cfg.Thor.lr_power, cfg.Thor.lr_total_steps)
        damping = get_bert_thor_damping(cfg.Thor.damping_max, cfg.Thor.damping_min,
                                        cfg.Thor.damping_power, cfg.Thor.damping_total_steps)
        # The all-reduce split points depend on network depth and position-embedding style.
        split_indices = None
        if bert_net_cfg.num_hidden_layers == 12:
            if bert_net_cfg.use_relative_positions:
                split_indices = [29, 58, 87, 116, 145, 174, 203, 217]
            else:
                split_indices = [28, 55, 82, 109, 136, 163, 190, 205]
        elif bert_net_cfg.num_hidden_layers == 24:
            if bert_net_cfg.use_relative_positions:
                split_indices = [30, 90, 150, 210, 270, 330, 390, 421]
            else:
                split_indices = [38, 93, 148, 203, 258, 313, 368, 397]
        optimizer = THOR(network, lr, damping, cfg.Thor.momentum, cfg.Thor.weight_decay,
                         cfg.Thor.loss_scale, cfg.batch_size,
                         decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                         split_indices=split_indices)
    else:
        raise ValueError("Unsupported optimizer {}; only [Lamb, Momentum, AdamWeightDecay, Thor] "
                         "are supported.".format(cfg.optimizer))
    return optimizer
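# --- Illustrative only, not part of the repository: a minimal sketch of the
# grouped-parameter pattern _get_optimizer builds, assuming a plain
# mindspore.nn.AdamWeightDecay and a typical BERT-style decay filter.
# The helper name and its defaults are hypothetical.
from mindspore import nn

def make_grouped_optimizer(network, lr_schedule, weight_decay=0.01, eps=1e-6):
    # Exclude LayerNorm and bias parameters from weight decay, mirroring the
    # decay_filter used in the Thor branch above.
    def decay_filter(p):
        return 'layernorm' not in p.name.lower() and 'bias' not in p.name.lower()

    params = network.trainable_params()
    decay_params = [p for p in params if decay_filter(p)]
    other_params = [p for p in params if not decay_filter(p)]
    group_params = [{'params': decay_params, 'weight_decay': weight_decay},
                    {'params': other_params, 'weight_decay': 0.0},
                    {'order_params': params}]  # keeps the original parameter order across groups
    return nn.AdamWeightDecay(group_params, learning_rate=lr_schedule, eps=eps)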
def train_process_bert_thor(q, device_id, epoch_size, device_num):
    """Run BERT pre-training with THOR on one device; report losses and step times via q."""
    os.system("mkdir " + str(device_id))
    os.chdir(str(device_id))

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
    context.set_context(reserve_class_name_in_scope=False)
    context.set_context(max_call_depth=3000)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
    os.environ['RANK_ID'] = str(device_id)
    os.environ['RANK_SIZE'] = str(device_num)

    # Initialize distributed data-parallel training with averaged gradients.
    D.init()
    rank = device_id % device_num
    context.reset_auto_parallel_context()
    _set_bert_all_reduce_split()
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                      device_num=device_num)

    data_set = create_bert_dataset(device_num=device_num, rank=rank, do_shuffle=False,
                                   data_dir=DATASET_PATH, schema_dir=None)
    net_with_loss = BertNetworkWithLoss(bert_net_cfg, True)

    # Cap the number of sink iterations by the configured total train steps.
    new_repeat_count = epoch_size * data_set.get_dataset_size() // data_sink_steps
    new_repeat_count = min(new_repeat_count, train_steps // data_sink_steps)

    lr = get_bert_thor_lr()
    damping = get_bert_thor_damping()
    split_indices = [38, 77]
    optimizer = THOR(net_with_loss, lr, damping, momentum, weight_decay, loss_scale, batch_size,
                     decay_filter=lambda x: 'layernorm' not in x.name.lower() and 'bias' not in x.name.lower(),
                     split_indices=split_indices)

    time_monitor_callback = TimeMonitor(data_sink_steps)
    loss_callback = LossCallback()
    callback = [time_monitor_callback, loss_callback]

    if load_checkpoint_path:
        param_dict = load_checkpoint(load_checkpoint_path)
        load_param_into_net(net_with_loss, param_dict)

    net_with_grads = BertTrainOneStepCell(net_with_loss, optimizer=optimizer)
    model = Model(net_with_grads)
    model = ConvertModelUtils().convert_to_thor_model(model, network=net_with_grads, optimizer=optimizer,
                                                      frequency=frequency)
    model.train(new_repeat_count, data_set, callbacks=callback, dataset_sink_mode=True,
                sink_size=data_sink_steps)

    loss_list = loss_callback.loss_list
    per_step_mseconds = time_monitor_callback.per_step_mseconds_list
    q.put({'loss': loss_list, 'cost': per_step_mseconds})
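# --- Illustrative only: a hypothetical launcher for train_process_bert_thor showing how
# the queue argument is typically consumed. Process-per-device spawning is an assumption;
# the surrounding test harness may differ.
from multiprocessing import Process, Queue

def launch_bert_thor(epoch_size, device_num):
    q = Queue()
    processes = [Process(target=train_process_bert_thor, args=(q, device_id, epoch_size, device_num))
                 for device_id in range(device_num)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    # Each worker puts exactly one dict with its loss curve and per-step times.
    return [q.get() for _ in range(device_num)]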
                      keep_batchnorm_fp32=False)
    else:
        ## fp32 training
        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                       lr, config.momentum, config.weight_decay)
        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

    if cfg.optimizer == "Thor" and args_opt.dataset == "imagenet2012":
        from src.lr_generator import get_thor_damping
        damping = get_thor_damping(0, config.damping_init, config.damping_decay, 70, step_size)
        split_indices = [26, 53]
        opt = THOR(net, lr, Tensor(damping), config.momentum, config.weight_decay, config.loss_scale,
                   config.batch_size, split_indices=split_indices)
        model = ConvertModelUtils().convert_to_thor_model(model=model, network=net, loss_fn=loss,
                                                          optimizer=opt, loss_scale_manager=loss_scale,
                                                          metrics={'acc'}, amp_level="O2",
                                                          keep_batchnorm_fp32=False,
                                                          frequency=config.frequency)

    # define callbacks
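# --- Illustrative only: a self-contained sketch of an exponentially decaying damping
# schedule in the spirit of get_thor_damping; the exact formula in src.lr_generator may
# differ, and the per-epoch decay unit is an assumption.
import numpy as np

def exp_decay_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
    steps = np.arange(total_epochs * steps_per_epoch)
    epochs = (steps + 1) / steps_per_epoch
    damping = damping_init * np.power(decay_rate, epochs)  # one decay unit per epoch (assumed)
    # Return the remaining schedule starting from the current global step.
    return damping[global_step:].astype(np.float32)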
def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl):
    """Train ResNet-50 with THOR on one Ascend device, evaluating every eval_interval epochs."""
    os.system("mkdir " + str(device_id))
    os.chdir(str(device_id))
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    context.set_context(device_id=device_id)
    os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH_2
    os.environ['RANK_ID'] = str(device_id - 4)
    os.environ['RANK_SIZE'] = str(device_num)
    if enable_hccl:
        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True, all_reduce_fusion_config=[85, 160])
        init()

    # network
    net = resnet50_thor(thor_config.class_num)

    if not thor_config.label_smooth:
        thor_config.label_smooth_factor = 0.0

    # loss
    loss = CrossEntropySmooth(sparse=True, reduction="mean", smooth_factor=thor_config.label_smooth_factor,
                              num_classes=thor_config.class_num)

    # train dataset
    dataset = create_dataset_thor(dataset_path=dataset_path, do_train=True,
                                  repeat_num=1, batch_size=thor_config.batch_size)
    step_size = dataset.get_dataset_size()
    eval_interval = thor_config.eval_interval

    # evaluation dataset
    eval_dataset = create_dataset(dataset_path=eval_path, do_train=False,
                                  repeat_num=1, batch_size=thor_config.eval_batch_size)

    # loss scale
    loss_scale = FixedLossScaleManager(thor_config.loss_scale, drop_overflow_update=False)

    # learning rate
    lr = get_thor_lr(0, 0.05803, 4.04839, 53, 5004, decay_epochs=39)
    damping = get_thor_damping(0, 0.02714, 0.50036, 70, 5004)

    # optimizer
    split_indices = [26, 53]
    opt = THOR(net, Tensor(lr), Tensor(damping), thor_config.momentum, thor_config.weight_decay,
               thor_config.loss_scale, thor_config.batch_size, split_indices=split_indices)

    # evaluation network
    dist_eval_network = ClassifyCorrectCell(net)

    # model
    model = THOR_Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale,
                       amp_level="O2", keep_batchnorm_fp32=False,
                       metrics={'acc': DistAccuracy(batch_size=thor_config.eval_batch_size,
                                                    device_num=device_num)},
                       eval_network=dist_eval_network, frequency=thor_config.frequency)

    # model init
    print("init_start", device_id)
    model.init(dataset, eval_dataset)
    print("init_stop", device_id)

    # callbacks
    loss_cb = LossGet(1, step_size)

    # train and eval
    acc = 0.0
    time_cost = 0.0
    print("run_start", device_id)
    for epoch_idx in range(0, int(epoch_size / eval_interval)):
        model.train(eval_interval, dataset, callbacks=loss_cb)
        eval_start = time.time()
        output = model.eval(eval_dataset)
        eval_cost = (time.time() - eval_start) * 1000
        acc = float(output["acc"])
        time_cost = loss_cb.get_per_step_time()
        loss = loss_cb.get_loss()
        print("the {} epoch's resnet result:\n "
              "device{}, training loss {}, acc {}, "
              "training per step cost {:.2f} ms, eval cost {:.2f} ms, total_cost {:.2f} ms"
              .format(epoch_idx, device_id, loss, acc, time_cost, eval_cost,
                      time_cost * step_size + eval_cost))
    q.put({'acc': acc, 'cost': time_cost})
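# --- Illustrative only: a minimal loss-capturing callback in the style of the LossGet
# used above, assuming the standard mindspore Callback interface. The real LossGet in
# the test suite also records per-step time; this sketch keeps only the loss.
from mindspore.train.callback import Callback

class SimpleLossGet(Callback):
    def __init__(self):
        super().__init__()
        self._loss = 0.0

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        outputs = cb_params.net_outputs
        # net_outputs may be a Tensor or a (loss, overflow, loss_scale) tuple.
        loss = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
        self._loss = float(loss.asnumpy())

    def get_loss(self):
        return self._loss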