# Imports assumed from the original test file; `Net` is the small
# two-parameter test network defined elsewhere in this file.
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.api import _executor
from mindspore.nn import TrainOneStepCell, WithLossCell


def test_adamweightdecay_group():
    """ test_adam_group_lr_and_weight_decay """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    net_with_loss = WithLossCell(net, loss)
    all_params = net.trainable_params()

    # Group 0 overrides both lr and weight_decay; group 1 inherits the
    # dynamic lr passed to the optimizer and the default weight_decay.
    schedule_lr = nn.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0)
    group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9},
                    {'params': [all_params[1]]}]
    optimizer = nn.AdamWeightDecay(group_params, learning_rate=schedule_lr)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)
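

# Hedged add-on sketch (not part of the original test): one way to inspect
# which learning rate each group above actually resolved to.
# get_lr_parameter() is a public mindspore.nn.Optimizer method; `Net` is again
# the two-parameter network assumed to be defined elsewhere in this file.
def show_group_lrs():
    net = Net()
    all_params = net.trainable_params()
    schedule_lr = nn.PolynomialDecayLR(0.01, 0.0001, 3, power=1.0)
    group_params = [{'params': [all_params[0]], 'lr': 0.02, 'weight_decay': 0.9},
                    {'params': [all_params[1]]}]
    optimizer = nn.AdamWeightDecay(group_params, learning_rate=schedule_lr)
    for param in all_params:
        # Group 0 pinned a static lr of 0.02; group 1 falls back to the
        # PolynomialDecayLR schedule passed at the optimizer level.
        print(param.name, optimizer.get_lr_parameter(param))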
        # Fragment: body of an accuracy-evaluation loop over the test set;
        # the enclosing function header is not part of this excerpt.
        output = net(train_x)
        log_output = P.LogSoftmax(axis=1)(output)
        acc = np.mean(log_output.asnumpy().argmax(axis=1) == label.asnumpy())
        accs.append(acc)

    acc_mean = np.mean(accs)
    return acc_mean


if __name__ == "__main__":
    network = LeNet5()
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True,
                                                 reduction="mean")
    optimizer = nn.AdamWeightDecay(params=network.trainable_params(),
                                   learning_rate=0.0001)
    net_with_loss = WithLossCell(network, criterion)
    train_network = TrainOneStepCell(net_with_loss, optimizer)

    # Wrap the deterministic training cell and convert every supported layer
    # to its Bayesian counterpart; 60000 and 1e-6 are the dnn/bnn loss factors.
    bnn_transformer = transforms.TransformToBNN(train_network, 60000, 0.000001)
    train_bnn_network = bnn_transformer.transform_to_bnn_model()
    train_bnn_network.set_train()

    train_set = create_dataset(
        '/home/workspace/mindspore_dataset/mnist_data/train', 64, 1)
    test_set = create_dataset(
        '/home/workspace/mindspore_dataset/mnist_data/test', 64, 1)

    epoch = 500
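
    # Hedged add-on sketch (assumption, not in the original script): the same
    # transformer can convert a single layer type instead of the whole model,
    # e.g. only the nn.Dense layers, via transform_to_bnn_layer. DenseReparam
    # is assumed to come from mindspore.nn.probability.bnn_layers.
    from mindspore.nn.probability import bnn_layers
    train_bnn_dense_only = bnn_transformer.transform_to_bnn_layer(
        nn.Dense, bnn_layers.DenseReparam)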
def run_train_pipeline(args_opt):
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id = int(os.getenv("RANK_ID"))
    local_rank = rank_id
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id), flush=True)
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target="Ascend",
                        device_id=device_id)
    context.set_context(variable_memory_max_size="31GB")
    strategy_ckpt_save_file = "/cache/" + "strategy" + str(local_rank) + ".ckpt"
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("device_id is {}, rank_id is {}, device_num is {}".format(
            device_id, rank, device_num))
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            device_num=device_num,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=bool(args_opt.optimizer_shard),
            pipeline_stages=args_opt.stage_num,
            strategy_ckpt_save_file=strategy_ckpt_save_file)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1

    # Derive the data-parallel dimension from the devices left in each
    # pipeline stage after model parallelism is carved out.
    model_parallel_num = args_opt.tensor_model_parallel_num
    stage_device_num = int(device_num / args_opt.stage_num)
    data_parallel_num = int(stage_device_num / model_parallel_num)
    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num * args_opt.micro_size

    config = PANGUALPHAConfig(data_parallel_num=data_parallel_num,
                              model_parallel_num=model_parallel_num,
                              batch_size=batch_size,
                              seq_length=args_opt.seq_length,
                              vocab_size=args_opt.vocab_size,
                              embedding_size=args_opt.embedding_size,
                              num_layers=args_opt.num_layers,
                              num_heads=args_opt.num_heads,
                              expand_ratio=4,
                              post_layernorm_residual=False,
                              dropout_rate=0.1,
                              compute_dtype=mstype.float16,
                              use_past=False,
                              self_layernorm=True,
                              forward_reduce_scatter=True,
                              stage_num=args_opt.stage_num,
                              micro_size=args_opt.micro_size,
                              word_emb_dp=False)
    print("===config is: ", config, flush=True)

    pangu_alpha = PANGUALPHAPipeline(config)
    loss = CrossEntropyLoss(config)
    pangu_alpha_with_loss = PANGUALPHAWithLossPipeline(config, pangu_alpha, loss)
    pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)
    print("=====args_opt is: ", args_opt, flush=True)

    lr = LearningRate(learning_rate=args_opt.start_lr,
                      end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step,
                      decay_steps=args_opt.decay_steps)

    # Collect only the parameters owned by this pipeline stage: the first
    # stage holds the word and position embeddings, the last stage holds the
    # final layernorm and top-query embedding, and every stage owns its own
    # slice of transformer blocks.
    per_stage_layers = config.num_layers // config.stage_num
    per_stage_devices = device_num // config.stage_num
    self_stage = rank_id // per_stage_devices
    range_min = self_stage * per_stage_layers
    range_max = range_min + per_stage_layers
    if self_stage == 0:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.pangu_alpha_embedding.
                      position_embedding.trainable_params())
    elif self_stage == config.stage_num - 1:
        params = [pangu_alpha.embedding_table]
        params.extend(pangu_alpha.backbone.layernorm.trainable_params())
        params.extend(
            pangu_alpha.backbone.top_query_embedding.trainable_params())
    else:
        params = []
    for i in range(range_min, range_max):
        params.extend(pangu_alpha.backbone.blocks[i].trainable_params())

    # Exclude layernorm weights and all biases from weight decay.
    decay_filter = lambda x: ('layernorm' not in x.name.lower()
                              and "bias" not in x.name.lower())
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': args_opt.weight_decay
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = nn.AdamWeightDecay(group_params,
                                       learning_rate=lr,
                                       beta1=0.9,
                                       beta2=0.95,
                                       eps=1e-8)

    save_steps = args_opt.save_steps
    ckpt_dir = os.path.join(args_opt.ckpt_save_dir, f"rank_{str(local_rank)}")
    if not os.path.exists(ckpt_dir):
        Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    ds = create_dataset(config.batch_size,
                        data_path=args_opt.data_url,
                        data_start_index=0)
    epoch_num = args_opt.epoch_size
    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, local_rank, config.stage_num)
    ]
    config_ck = CheckpointConfig(save_checkpoint_steps=save_steps,
                                 keep_checkpoint_max=1,
                                 integrated_save=False,
                                 filter_prefix="accu_grads")
    ckpoint_cb = ModelCheckpoint(prefix="PanguAlpha",
                                 directory=ckpt_dir,
                                 config=config_ck)
    callback.append(ckpoint_cb)

    # Dynamic loss scaling for fp16 training.
    loss_scale_value = math.pow(2, 32)
    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
                                             scale_factor=2,
                                             scale_window=1000)
    pangu_alpha_with_grads = PANGUALPHATrainPipelineWithLossScaleCell(
        pangu_alpha_with_loss, optimizer=optimizer, config=config,
        scale_update_cell=update_cell)

    model = Model(pangu_alpha_with_grads)
    de.config.set_sending_batches(2 * args_opt.sink_size)
    model.train(actual_epoch_num,
                ds,
                callbacks=callback,
                sink_size=callback_size,
                dataset_sink_mode=True)
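

# Hedged standalone sketch (illustrative, not from the training script): the
# stage arithmetic above assigns each rank a contiguous slice of transformer
# layers. For example, with 32 layers, 4 pipeline stages and 8 devices,
# rank 5 lands in stage 2 and owns layers [16, 24).
def stage_layer_range(rank_id, num_layers, stage_num, device_num):
    per_stage_layers = num_layers // stage_num
    per_stage_devices = device_num // stage_num
    self_stage = rank_id // per_stage_devices
    range_min = self_stage * per_stage_layers
    return self_stage, range_min, range_min + per_stage_layers


assert stage_layer_range(5, 32, 4, 8) == (2, 16, 24)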