Esempio n. 1
0
def trainHardTest(optCls, dtype, *args, **kwargs):
    """Smoke-train a small convolutional net with the given optimizer class.

    The network is Conv2D(4 -> 8, kernel 5, pad 1) -> BatchNorm2D(8) -> ReLU ->
    Conv2D(8 -> 16, kernel 5, pad 1). Those layers are switched to ``dtype``
    via ``calcMode``. A Cast to float32 is then appended, so the MSE cost
    always receives float32 predictions. The net is fitted for 200 steps to
    one fixed random batch, and the error is printed every 5 steps.

    Args:
        optCls: optimizer class. It is instantiated as ``optCls(*args, **kwargs)``
            and attached with ``setupOn(..., useGlobalState=True)``.
        dtype: numpy dtype used for the input batch and the layers' calc mode.
        *args, **kwargs: forwarded unchanged to the optimizer constructor.

    NOTE: relies on the module-level names ``np`` and ``gpuarray``.
    """
    from PuzzleLib.Containers.Sequential import Sequential
    from PuzzleLib.Modules.Conv2D import Conv2D
    from PuzzleLib.Modules.BatchNorm2D import BatchNorm2D
    from PuzzleLib.Modules.Activation import Activation, relu
    from PuzzleLib.Modules.Cast import Cast
    from PuzzleLib.Cost.MSE import MSE

    net = Sequential()
    addLayer = net.append

    addLayer(Conv2D(4, 8, 5, pad=1))
    addLayer(BatchNorm2D(8))
    addLayer(Activation(relu))
    addLayer(Conv2D(8, 16, 5, pad=1))

    # calcMode runs before the Cast is added, so only the layers above are
    # switched to ``dtype``; the Cast then converts the output to float32.
    net.calcMode(dtype)
    addLayer(Cast(intype=dtype, outtype=np.float32))

    optimizer = optCls(*args, **kwargs)
    optimizer.setupOn(net, useGlobalState=True)

    cost = MSE()

    # The two kernel-5 / pad-1 convolutions shrink 5x5 -> 3x3 -> 1x1
    # (assuming Conv2D's default stride is 1), hence the (4, 16, 1, 1) target.
    batch = gpuarray.to_gpu(np.random.randn(4, 4, 5, 5).astype(dtype))
    target = gpuarray.to_gpu(np.random.randn(4, 16, 1, 1).astype(np.float32))

    for step in range(1, 201):
        prediction = net(batch)
        error, grad = cost(prediction, target)

        optimizer.zeroGradParams()
        net.backward(grad)
        optimizer.update()

        if step % 5 == 0:
            print("Iteration #%d error: %s" % (step, error))
Esempio n. 2
0
def trainSimpleTest(optCls, dtype, *args, **kwargs):
    """Smoke-train a small fully connected net with the given optimizer class.

    The network is Linear(128 -> 64, no bias) -> ReLU ->
    Linear(64 -> 32, no bias) -> ReLU -> Linear(32 -> 16). Those layers are
    switched to ``dtype`` via ``calcMode``. A Cast to float32 is then appended,
    so the MSE cost always receives float32 predictions. The net is fitted for
    200 steps to one fixed random batch of 16 samples, and the error is
    printed every 5 steps.

    Args:
        optCls: optimizer class. It is instantiated as ``optCls(*args, **kwargs)``
            and attached with ``setupOn(..., useGlobalState=True)``.
        dtype: numpy dtype used for the input batch and the layers' calc mode.
        *args, **kwargs: forwarded unchanged to the optimizer constructor.

    NOTE: relies on the module-level names ``np`` and ``gpuarray``.
    """
    from PuzzleLib.Containers.Sequential import Sequential
    from PuzzleLib.Modules.Linear import Linear
    from PuzzleLib.Modules.Activation import Activation, relu
    from PuzzleLib.Modules.Cast import Cast
    from PuzzleLib.Cost.MSE import MSE

    net = Sequential()

    # Two hidden stages, each a bias-free Linear followed by a ReLU.
    for inSize, outSize in ((128, 64), (64, 32)):
        net.append(Linear(inSize, outSize, useBias=False))
        net.append(Activation(relu))

    net.append(Linear(32, 16))

    # calcMode runs before the Cast is added, so only the layers above are
    # switched to ``dtype``; the Cast then converts the output to float32.
    net.calcMode(dtype)
    net.append(Cast(intype=dtype, outtype=np.float32))

    optimizer = optCls(*args, **kwargs)
    optimizer.setupOn(net, useGlobalState=True)

    cost = MSE()

    batch = gpuarray.to_gpu(np.random.randn(16, 128).astype(dtype))
    target = gpuarray.to_gpu(np.random.randn(16, 16).astype(np.float32))

    for step in range(1, 201):
        prediction = net(batch)
        error, grad = cost(prediction, target)

        optimizer.zeroGradParams()
        net.backward(grad)
        optimizer.update()

        if step % 5 == 0:
            print("Iteration #%d error: %s" % (step, error))