netParams = snn.params('network.yaml')

    # Run on CUDA. No device index is given, so PyTorch's current CUDA device
    # is used.
    device = torch.device('cuda')
    # Device indices for the disabled DataParallel variant below.
    # deviceIds = [2, 3]

    # Build the network from the loaded parameters and move it to the device.
    net = Network(netParams).to(device)
    # Disabled multi-GPU alternative: wrap the network in DataParallel.
    # net = torch.nn.DataParallel(Network(netParams).to(device), device_ids=deviceIds)

    # Create the snn loss object from the same parameters and `spikeLayer`
    # (bound outside this excerpt), and move it to the same device.
    error = snn.loss(netParams, spikeLayer).to(device)

    # Optimizer over all of the network's parameters.
    # Disabled alternative: torch's Adam with the same lr / amsgrad settings.
    # optimizer = torch.optim.Adam(net.parameters(), lr = 0.01, amsgrad = True)
    # NOTE(review): this statement reads `optimizer` (for `.Nadam`) and rebinds
    # the same name to the result, so afterwards `optimizer` no longer refers
    # to whatever object provided `Nadam`. If the enclosing scope is a function,
    # `optimizer` is a local name and the read raises UnboundLocalError --
    # confirm the scope, or use a distinct name for the module.
    optimizer = optimizer.Nadam(net.parameters(), lr=0.01, amsgrad=True)

    # Training dataset: the dataset location and sample-list file come from
    # netParams['training']['path']; the sampling time and sample length come
    # from netParams['simulation'].
    trainingSet = IBMGestureDataset(
        datasetPath=netParams['training']['path']['in'],
        sampleFile=netParams['training']['path']['train'],
        samplingTime=netParams['simulation']['Ts'],
        sampleLength=netParams['simulation']['tSample'])
    # Training loader: batches of 4, reshuffled each epoch, loaded by one
    # worker process.
    trainLoader = DataLoader(dataset=trainingSet,
                             batch_size=4,
                             shuffle=True,
                             num_workers=1)

    testingSet = IBMGestureDataset(
        datasetPath=netParams['training']['path']['in'],
        sampleFile=netParams['training']['path']['test'],
Example #2
0
    # Load the network/simulation configuration from 'network.yaml' (path is
    # relative, so it is resolved against the current working directory).
    netParams = snn.params('network.yaml')

    # Run on CUDA. No device index is given, so PyTorch's current CUDA device
    # is used.
    device = torch.device('cuda')
    # Device indices for the disabled DataParallel variant below.
    # deviceIds = [1, 2]

    # Build the network from the loaded parameters plus the two string
    # arguments 'emg' and 'dvsCropped', then move it to the device.
    # NOTE(review): presumably these strings select the two input modalities
    # being fused -- confirm against the Network constructor.
    net = Network(netParams, 'emg', 'dvsCropped').to(device)
    # Disabled multi-GPU alternative. Note it calls Network with netParams
    # only, unlike the active line above.
    # net = torch.nn.DataParallel(Network(netParams).to(device), device_ids=deviceIds)

    # Create the snn loss object from the parameters and `snn.loihi`, and move
    # it to the same device.
    error = snn.loss(netParams, snn.loihi).to(device)

    # Nadam optimizer over all of the network's parameters, learning rate 0.01.
    optimizer = optim.Nadam(net.parameters(), lr=0.01)

    # Training dataset. The sample list is read from 'train.txt' with
    # np.loadtxt and cast to int; the sampling time and sample length come
    # from netParams['simulation'].
    # NOTE(review): presumably each entry in the text file is a sample
    # index/ID -- confirm against fusionDataset.
    trainingSet = fusionDataset(
        samples=np.loadtxt('train.txt').astype(int),
        samplingTime=netParams['simulation']['Ts'],
        sampleLength=netParams['simulation']['tSample'],
        # Disabled override: fixed sample length instead of the configured one.
        # sampleLength=2000,
    )

    # Testing dataset: same construction as the training set, but with the
    # sample list read from 'test.txt'.
    testingSet = fusionDataset(
        samples=np.loadtxt('test.txt').astype(int),
        samplingTime=netParams['simulation']['Ts'],
        sampleLength=netParams['simulation']['tSample'],
    )