Example No. 1
def test_lars():
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
            self.relu = P.ReLU().shard(strategy2)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            out = self.relu(out)
            return out

    # 4-device semi-auto parallel; each strategy tuple below gives the shard layout for one operator input
    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 1), (2, 1))
    strategy2 = ((4, 1),)
    strategy3 = ((4, 1), (4, 1))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(strategy1, strategy2, weight)

    lr = Tensor(np.ones([6]), dtype=ms.float32)
    sgd = Momentum(net.trainable_params(), lr, 0.9)
    # LARS wraps the base Momentum optimizer; lars_filter keeps BatchNorm parameters out of the LARS scaling
    optimizer = LARS(sgd, epsilon=1e-08, coefficient=0.02,
                     lars_filter=lambda x: 'bn' not in x.name)
    net_with_loss = NetWithLoss(net, strategy3)
    train_net = TrainOneStepCell(net_with_loss, optimizer)

    compile_net(train_net, x, b)
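
The snippet above is not self-contained: NetWithLoss and compile_net are helpers defined elsewhere in the same test module, and the imports are omitted. A plausible import block for it, assuming MindSpore's standard 1.x module layout (an assumption, not copied from the original file), would be:

# Assumed imports for Example No. 1 (MindSpore 1.x layout); NetWithLoss and
# compile_net are local test helpers that still have to be provided separately.
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter, context
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum, LARS
from mindspore.ops import operations as P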
Example No. 2
def test_lars():
    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
    label = Tensor(np.zeros([1, 10]).astype(np.float32))
    net = Net()
    net.set_train()
    loss = nn.SoftmaxCrossEntropyWithLogits()

    lr = multisteplr(10, [2, 6])
    SGD = Momentum(net.trainable_params(), lr, 0.9)
    optimizer = LARS(SGD,
                     epsilon=1e-08,
                     hyperpara=0.02,
                     decay_filter=lambda x: 'bn' not in x.name,
                     lars_filter=lambda x: 'bn' not in x.name)

    net_with_loss = WithLossCell(net, loss)
    train_network = TrainOneStepCell(net_with_loss, optimizer)
    _executor.compile(train_network, inputs, label)
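
Here Net and multisteplr are defined elsewhere in the test file, and the LARS call uses the hyperpara/decay_filter keyword names instead of the coefficient/lars_filter pair from Example No. 1, so the two snippets appear to target different MindSpore releases. Below is a minimal sketch of the kind of helper multisteplr looks like here, building a per-epoch learning-rate tensor with step decays at the given milestones; it is a hypothetical reimplementation (the names, base_lr, and gamma values are assumptions), not the original helper:

import numpy as np
import mindspore as ms
from mindspore import Tensor

def multisteplr(total_epochs, milestones, base_lr=0.1, gamma=0.1):
    # Hypothetical multistep schedule: one learning rate per epoch, multiplied
    # by gamma each time an epoch listed in `milestones` is reached.
    rates = []
    lr = base_lr
    for epoch in range(total_epochs):
        if epoch in milestones:
            lr *= gamma
        rates.append(lr)
    return Tensor(np.array(rates), dtype=ms.float32)

# multisteplr(10, [2, 6]) -> a 10-element Tensor with drops at epochs 2 and 6,
# usable as the dynamic learning rate passed to Momentum in the snippet above.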
Example No. 3
    }, {
        'params': no_decayed_params
    }, {
        'order_params': net.trainable_params()
    }]

    if args.use_lars:
        # LARS wraps the base Momentum optimizer; the filters below keep
        # beta/gamma/bias parameters out of both weight decay and LARS scaling.
        sgd = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
                       lr,
                       args.momentum,
                       use_nesterov=args.use_nesterov)
        opt = LARS(sgd,
                   epsilon=args.lars_epsilon,
                   hyperpara=args.lars_coefficient,
                   weight_decay=args.weight_decay,
                   decay_filter=lambda x: 'beta' not in x.name and 'gamma'
                   not in x.name and 'bias' not in x.name,
                   lars_filter=lambda x: 'beta' not in x.name and 'gamma'
                   not in x.name and 'bias' not in x.name,
                   loss_scale=args.loss_scale)
    else:
        opt = Momentum(group_params,
                       lr,
                       args.momentum,
                       weight_decay=args.weight_decay,
                       loss_scale=args.loss_scale,
                       use_nesterov=args.use_nesterov)

    # model
    model = Model(net,
                  loss_fn=loss,
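
Example No. 3 is clipped at both ends: its opening lines are the tail of a parameter-group list and the closing Model(...) call continues past what is shown. A hedged sketch of the grouping those opening lines usually belong to (the name no_decayed_params is visible in the snippet; decayed_params, the selection logic, and the weight_decay entry are assumptions) looks like:

# Assumed construction of the parameter groups whose tail appears at the top of
# Example No. 3: decayed parameters receive weight decay, the rest do not, and
# 'order_params' keeps the original parameter ordering for the optimizer.
decayed_params = [p for p in net.trainable_params()
                  if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name]
no_decayed_params = [p for p in net.trainable_params()
                     if 'beta' in p.name or 'gamma' in p.name or 'bias' in p.name]
group_params = [{'params': decayed_params, 'weight_decay': args.weight_decay},
                {'params': no_decayed_params},
                {'order_params': net.trainable_params()}]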