def test_lars():
    """Compile a semi-auto-parallel MatMul+ReLU net trained via LARS.

    Builds a tiny sharded network, wraps a Momentum optimizer in LARS,
    and runs the parallel compilation path to verify it succeeds under
    the given shard strategies.
    """
    class Net(nn.Cell):
        def __init__(self, matmul_strategy, relu_strategy, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            # transpose_b=True: weight is stored as (out, in) and transposed in the matmul.
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(matmul_strategy)
            self.relu = P.ReLU().shard(relu_strategy)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            return self.relu(out)

    # Parallel context must be configured before the network is built.
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    # Shard strategies: split the batch dimension across devices.
    matmul_strategy = ((2, 1), (2, 1))
    relu_strategy = ((4, 1),)
    loss_strategy = ((4, 1), (4, 1))

    inputs = Tensor(np.ones([64, 32]), dtype=ms.float32)
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)
    label = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(matmul_strategy, relu_strategy, weight)

    # One learning-rate value per step (dynamic LR schedule of length 6).
    learning_rate = Tensor(np.ones([6]), dtype=ms.float32)
    momentum_opt = Momentum(net.trainable_params(), learning_rate, 0.9)
    # LARS skips layer-wise adaptation for batch-norm parameters.
    optimizer = LARS(momentum_opt, epsilon=1e-08, coefficient=0.02,
                     lars_filter=lambda x: 'bn' not in x.name)

    net_with_loss = NetWithLoss(net, loss_strategy)
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    compile_net(train_net, inputs, label)
def test_lars():
    """Compile one training step of ``Net`` with a LARS-wrapped optimizer.

    Wraps a Momentum optimizer in LARS (filtering batch-norm parameters
    out of both weight decay and the LARS adaptation), then compiles the
    resulting one-step training network on a dummy input/label pair.
    """
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))

    net = Net()
    net.set_train()

    loss = nn.SoftmaxCrossEntropyWithLogits()
    # Multi-step LR schedule over 10 steps with milestones at steps 2 and 6.
    lr = multisteplr(10, [2, 6])

    base_optimizer = Momentum(net.trainable_params(), lr, 0.9)
    # Exclude batch-norm parameters from both decay and LARS scaling.
    optimizer = LARS(base_optimizer, epsilon=1e-08, hyperpara=0.02,
                     decay_filter=lambda x: 'bn' not in x.name,
                     lars_filter=lambda x: 'bn' not in x.name)

    train_network = TrainOneStepCell(WithLossCell(net, loss), optimizer)
    _executor.compile(train_network, inputs, label)
}, { 'params': no_decayed_params }, { 'order_params': net.trainable_params() }] if args.use_lars: sgd = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, args.momentum, use_nesterov=args.use_nesterov) opt = LARS(sgd, epsilon=args.lars_epsilon, hyperpara=args.lars_coefficient, weight_decay=args.weight_decay, decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, lars_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, loss_scale=args.loss_scale) else: opt = Momentum(group_params, lr, args.momentum, weight_decay=args.weight_decay, loss_scale=args.loss_scale, use_nesterov=args.use_nesterov) # model model = Model(net, loss_fn=loss,