import os

import torch
from tqdm import tqdm

# `ROOT_DIR`, `location` (a torch.load map_location such as 'cpu'), and the
# `update_ngram_scores` / `ngram` helpers are assumed to come from the
# surrounding project (bindsnet.evaluation), as in the original source.
def main():
    path = os.path.join(
        ROOT_DIR, 'spikes', 'mnist', 'crop_locally_connected',
        'train_2_12_4_100_4_0.01_0.99_60000_250.0_250_1.0_0.05_1e-07_0.5_0.2_10_250'
    )

    ngram_scores = {}
    for i in tqdm(range(200, 240)):
        f = os.path.join(path, f'{i}.pt')
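        # Each file holds one recorded batch: a spike tensor and its labels.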
        spikes, labels = torch.load(f, map_location=location)
        ngram_scores = update_ngram_scores(spikes=spikes,
                                           labels=labels,
                                           n_labels=10,
                                           n=1,
                                           ngram_scores=ngram_scores)

    all_labels = torch.LongTensor()
    all_predictions = torch.LongTensor()
    for i in tqdm(range(200, 240)):
        f = os.path.join(path, f'{i}.pt')
        spikes, labels = torch.load(f, map_location=location)
        predictions = ngram(spikes=spikes,
                            ngram_scores=ngram_scores,
                            n_labels=10,
                            n=1)
        all_labels = torch.cat([all_labels, labels.long()])
        all_predictions = torch.cat([all_predictions, predictions.long()])

    accuracy = (all_labels == all_predictions).float().mean() * 100
    print(f'Training accuracy: {accuracy:.2f}')

    path = os.path.join(
        ROOT_DIR, 'spikes', 'mnist', 'crop_locally_connected',
        'test_2_12_4_100_4_0.01_0.99_60000_10000_250.0_250_1.0_0.05_1e-07_0.5_0.2_10_250'
    )

    all_labels = torch.LongTensor()
    all_predictions = torch.LongTensor()
    for i in tqdm(range(1, 40)):
        f = os.path.join(path, f'{i}.pt')
        spikes, labels = torch.load(f, map_location=location)
        predictions = ngram(spikes=spikes,
                            ngram_scores=ngram_scores,
                            n_labels=10,
                            n=1)
        all_labels = torch.cat([all_labels, labels.long()])
        all_predictions = torch.cat([all_predictions, predictions.long()])

    accuracy = (all_labels == all_predictions).float().mean() * 100
    print(f'Test accuracy: {accuracy:.2f}')
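
For intuition, here is a toy reimplementation of the unigram (n=1) scheme used above: each neuron accrues a per-label score proportional to how often it fires for examples of that label, and prediction is an argmax over spike-count-weighted scores. This is an illustrative sketch with hypothetical names (`unigram_scores`, `unigram_predict`), not bindsnet's `update_ngram_scores`/`ngram`, which handle general n-grams over spike orderings:

import torch

def unigram_scores(spikes: torch.Tensor, labels: torch.Tensor,
                   n_labels: int) -> torch.Tensor:
    # spikes: binary tensor of shape (n_examples, time, n_neurons).
    counts = spikes.sum(1).float()                  # (n_examples, n_neurons)
    scores = torch.zeros(spikes.size(2), n_labels)
    for c, label in zip(counts, labels):
        scores[:, int(label)] += c                  # per-neuron, per-label firing counts
    return scores

def unigram_predict(spikes: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
    counts = spikes.sum(1).float()                  # (n_examples, n_neurons)
    return (counts @ scores).argmax(1)              # most-supported label per example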
Example #2
def update_curves(curves: Dict[str, list], labels: torch.Tensor, n_classes: int,
                  **kwargs) -> Tuple[Dict[str, list], Dict[str, torch.Tensor]]:
    # language=rst
    """
    Updates accuracy curves for each classification scheme.

    :param curves: Mapping from name of classification scheme to list of accuracy evaluations.
    :param labels: One-dimensional ``torch.Tensor`` of integer data labels.
    :param n_classes: Number of data categories.
    :param kwargs: Additional keyword arguments for classification scheme evaluation functions.
    :return: Updated accuracy curves and predictions.
    """
    predictions = {}
    for scheme in curves:
        # Branch based on name of classification scheme
        if scheme == 'all':
            spike_record = kwargs['spike_record']
            assignments = kwargs['assignments']

            prediction = all_activity(spike_record, assignments, n_classes)
        elif scheme == 'proportion':
            spike_record = kwargs['spike_record']
            assignments = kwargs['assignments']
            proportions = kwargs['proportions']

            prediction = proportion_weighting(spike_record, assignments,
                                              proportions, n_classes)
        elif scheme == 'ngram':
            spike_record = kwargs['spike_record']
            ngram_scores = kwargs['ngram_scores']
            n = kwargs['n']

            prediction = ngram(spike_record, ngram_scores, n_classes, n)
        elif scheme == 'logreg':
            full_spike_record = kwargs['full_spike_record']
            logreg = kwargs['logreg']

            prediction = logreg_predict(spikes=full_spike_record,
                                        logreg=logreg)

        else:
            raise NotImplementedError

        # Compute accuracy with current classification scheme.
        predictions[scheme] = prediction
        accuracy = torch.sum(labels.long() == prediction).float() / len(labels)
        curves[scheme].append(100 * accuracy)

    return curves, predictions
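
A minimal usage sketch for `update_curves` with the 'ngram' scheme, on random stand-in data; the `update_ngram_scores` import from `bindsnet.evaluation` is an assumption about the surrounding project:

import torch
from bindsnet.evaluation import update_ngram_scores

# Stand-in data: 250 examples, 100 timesteps, 50 neurons, 10 classes.
spike_record = torch.rand(250, 100, 50).round()
labels = torch.randint(0, 10, (250,))

# Build unigram scores, then evaluate; `n` must match between the two calls.
ngram_scores = update_ngram_scores(spikes=spike_record, labels=labels,
                                   n_labels=10, n=1, ngram_scores={})
curves, preds = update_curves({'ngram': []}, labels, n_classes=10,
                              spike_record=spike_record,
                              ngram_scores=ngram_scores, n=1)
print(curves['ngram'][-1])  # accuracy (percent) on the stand-in batch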
Example #3
    def batch_predictions(self, images):
        spike_record = torch.zeros(len(images), 250, self._model.layers['Y'].n)
        for i, image in enumerate(images):
            self._model.run(
                inpts={'X': poisson(datum=torch.tensor(image), time=250, dt=1)},
                time=250, dt=1)

            # Read the monitor before resetting: reset_() also clears monitors.
            spike_record[i] = self._model.monitors['Y_spikes'].get('s').t()
            self._model.reset_()

        labels = ngram(spike_record, self._ngram_scores, self.num_classes(),
                       2).numpy()
        return labels
Example #4
    def predictions(self, image):
        global axes, ims

        self._model.run(
            inpts={'X': poisson(datum=torch.tensor(image), time=250, dt=1)},
            time=250, dt=1)
        spike_record = self._model.monitors['Y_spikes'].get('s').t().unsqueeze(0)
        label = ngram(spike_record, self._ngram_scores, self.num_classes(),
                      2).numpy()[0]

        self._model.reset_()

        # axes, ims = plot_input(image.reshape(20, 20), image.reshape(20, 20), axes=axes, ims=ims)
        # plt.pause(0.05)

        print(label)

        return label
Example #5
def main(seed=0, n_examples=100, gpu=False, plot=False):
    # `load_network`, `poisson`, `ngram`, `Monitor`, `MNIST`, the plot_*
    # helpers, and the module-level `params_path` / `data_path` are assumed
    # to be imported from the surrounding project (bindsnet), as in the
    # original source.

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    model_name = '0_12_4_150_4_0.01_0.99_60000_250.0_250_1.0_0.05_1e-07_0.5_0.2_10_250'

    network = load_network(os.path.join(params_path, f'{model_name}.pt'))

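    # Propagate the network-level simulation timestep to every layer and
    # connection of the loaded network.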
    for l in network.layers:
        network.layers[l].dt = network.dt

    for c in network.connections:
        network.connections[c].dt = network.dt

    network.layers['Y'].one_spike = True
    network.layers['Y'].lbound = None

    kernel_size = 12
    side_length = 20
    n_filters = 150
    time = 250
    intensity = 0.5
    crop = 4
    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod
    n_classes = 10

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    images, labels = dataset.get_test()
    images *= intensity
    images = images[:, crop:-crop, crop:-crop]

    # Neuron assignments and spike proportions.
    path = os.path.join(params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
    assignments, proportions, rates, ngram_scores = torch.load(open(
        path, 'rb'))

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    print('\nBegin black box adversarial attack.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    inpt_ims = None
    inpt_axes = None

    max_iters = 25
    delta = 0.1
    epsilon = 0.1

    for i in range(n_examples):
        # Get next input sample.
        original = images[i % len(images)].contiguous().view(-1)
        label = labels[i % len(images)]

        # Check if the image is correctly classified.
        sample = poisson(datum=original, time=time)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Check for incorrect classification.
        s = spikes['Y'].get('s').view(1, n_neurons, time)
        prediction = ngram(spikes=s,
                           ngram_scores=ngram_scores,
                           n_labels=10,
                           n=2).item()

        if prediction != label:
            continue

        # Create adversarial example.
        adversarial = False
        while not adversarial:
            adv_example = 255 * torch.rand(original.size())
            sample = poisson(datum=adv_example, time=time)
            inpts = {'X': sample}

            # Run the network on the input.
            network.run(inpts=inpts, time=time)

            # Check for incorrect classification.
            s = spikes['Y'].get('s').view(1, n_neurons, time)
            prediction = ngram(spikes=s,
                               ngram_scores=ngram_scores,
                               n_labels=n_classes,
                               n=2).item()

            # A valid starting point must already be misclassified.
            if prediction != label:
                adversarial = True

        j = 0
        current = original.clone()
        while j < max_iters:
            # Orthogonal perturbation.
            # perturb = orthogonal_perturbation(delta=delta, image=adv_example, target=original)
            # temp = adv_example + perturb

            # # Forward perturbation.
            # temp = temp.clone() + forward_perturbation(epsilon * get_diff(temp, original), temp, adv_example)

            # print(temp)

            perturbation = torch.randn(original.size())

            unnormed_source_direction = original - perturbation
            source_norm = torch.norm(unnormed_source_direction)
            source_direction = unnormed_source_direction / source_norm

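            # Remove the component of the noise along the source direction,
            # leaving an orthogonal step, rescaled to epsilon times the
            # source-direction norm.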
            dot = torch.dot(perturbation, source_direction)
            perturbation -= dot * source_direction
            perturbation *= epsilon * source_norm / torch.norm(perturbation)

            D = 1 / np.sqrt(epsilon**2 + 1)
            direction = perturbation - unnormed_source_direction
            spherical_candidate = current + D * direction

            spherical_candidate = torch.clamp(spherical_candidate, 0, 255)

            new_source_direction = original - spherical_candidate
            new_source_direction_norm = torch.norm(new_source_direction)

            # Step length if spherical_candidate were exactly on the sphere.
            length = delta * source_norm

            # Correct for the candidate's deviation from the sphere.
            deviation = new_source_direction_norm - source_norm
            length += deviation

            # Make sure the step size is positive, then normalize.
            length = max(0, length)
            length = length / new_source_direction_norm

            candidate = spherical_candidate + length * new_source_direction
            candidate = torch.clamp(candidate, 0, 255)

            sample = poisson(datum=candidate, time=time)
            inpts = {'X': sample}

            # Run the network on the input.
            network.run(inpts=inpts, time=time)

            # Check for incorrect classification.
            s = spikes['Y'].get('s').view(1, n_neurons, time)
            prediction = ngram(spikes=s,
                               ngram_scores=ngram_scores,
                               n_labels=10,
                               n=2).item()

            # Optionally plot various simulation information.
            if plot:
                _input = original.view(side_length, side_length)
                reconstruction = candidate.view(side_length, side_length)
                _spikes = {
                    'X': spikes['X'].get('s').view(side_length**2, time),
                    'Y': spikes['Y'].get('s').view(n_neurons, time)
                }
                w = network.connections['X', 'Y'].w

                spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_locally_connected_weights(w,
                                                            n_filters,
                                                            kernel_size,
                                                            conv_size,
                                                            locations,
                                                            side_length,
                                                            im=weights_im)
                inpt_axes, inpt_ims = plot_input(_input,
                                                 reconstruction,
                                                 label=labels[i],
                                                 ims=inpt_ims,
                                                 axes=inpt_axes)

                plt.pause(1e-8)

            if prediction == label:
                print('Attack failed.')
            else:
                print('Attack succeeded.')
                adv_example = candidate

            j += 1

        network.reset_()  # Reset state variables.

    print('\nAdversarial attack complete.\n')
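
The inner loop above performs one "orthogonal step plus forward step" iteration of a boundary-style attack. Below is the same math extracted into a self-contained helper for readability; the name `boundary_step` is hypothetical, and note that, as in the original loop, the source direction is measured from the fresh noise rather than from `current`:

import numpy as np
import torch

def boundary_step(original: torch.Tensor, current: torch.Tensor,
                  delta: float = 0.1, epsilon: float = 0.1) -> torch.Tensor:
    # Draw random noise and remove its component along the source direction,
    # leaving an epsilon-sized orthogonal step.
    perturbation = torch.randn(original.size())
    unnormed_source_direction = original - perturbation
    source_norm = torch.norm(unnormed_source_direction)
    source_direction = unnormed_source_direction / source_norm

    perturbation -= torch.dot(perturbation, source_direction) * source_direction
    perturbation *= epsilon * source_norm / torch.norm(perturbation)

    # Step along the sphere around the original, clamped to pixel range.
    D = 1 / np.sqrt(epsilon ** 2 + 1)
    spherical_candidate = torch.clamp(
        current + D * (perturbation - unnormed_source_direction), 0, 255)

    # Forward step toward the original: delta-sized, corrected for the
    # candidate's deviation from the sphere.
    new_source_direction = original - spherical_candidate
    new_norm = torch.norm(new_source_direction)
    length = max(0, delta * source_norm + new_norm - source_norm) / new_norm
    return torch.clamp(spherical_candidate + length * new_source_direction, 0, 255)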
Example #6
def main():
    # `load`, `poisson`, `ngram`, `update_curves`, `print_results`, `Monitor`,
    # the plot_* helpers, `confusion_matrix` (sklearn), and `t` (the `time`
    # module, i.e. `import time as t`) are assumed to be imported as in the
    # original source.

    # hyperparameters
    n_neurons = 100
    n_test = 10000
    inhib = 100
    time = 350
    dt = 1
    intensity = 0.25
    # extra args
    progress_interval = 10
    update_interval = 250
    plot = True
    seed = 0
    train = True
    gpu = False
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    # Evaluation assumes whole update intervals.
    assert n_test % update_interval == 0
    np.random.seed(seed)
    save_weights_fn = "plots_snn/weights/weights_test.png"
    save_performance_fn = "plots_snn/performance/performance_test.png"
    save_assignments_fn = "plots_snn/assignments/assignments_test.png"
    # load network
    network = load('net_output.pt')  # here goes file with network to load
    network.train(False)

    # pull dataset
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/test.pt')
    data = data * intensity
    data_stretched = data.view(len(data), -1, 784)
    testset = torch.utils.data.TensorDataset(data_stretched, targets)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=1,
                                             shuffle=True)
    # spike init
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_test, n_neurons).long()
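    # `spike_record` holds full spike trains for the current update window;
    # `full_spike_record` keeps only per-neuron spike counts per test example.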
    # load parameters
    assignments, proportions, rates, ngram_scores = torch.load(
        'parameters_output.pt')  # here goes file with parameters to load
    # accuracy initialization
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)
    print("Begin test.")
    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None
    i = 0
    current_labels = torch.zeros(update_interval)

    # test
    test_time = t.time()
    time1 = t.time()
    for sample, label in testloader:
        sample = sample.view(1, 1, 28, 28)
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_test} ({t.time() - time1:.4f} s per sample)')
        if i % update_interval == 0 and i > 0:
            # update accuracy evaluation
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)
            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)
        sample_enc = poisson(datum=sample, time=time, dt=dt)
        inpts = {'X': sample_enc}
        # Run the network on the input.
        network.run(inputs=inpts, time=time)
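        # If the excitatory layer stayed silent, double the input intensity
        # and re-run (up to three retries).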
        retries = 0
        while spikes['Ae'].get('s').sum() < 1 and retries < 3:
            retries += 1
            sample = sample * 2
            inpts = {'X': poisson(datum=sample, time=time, dt=dt)}
            network.run(inputs=inpts, time=time)

        # Spike recording.
        spike_record[i % update_interval] = spikes['Ae'].get('s').view(
            time, n_neurons)
        full_spike_record[i] = spikes['Ae'].get('s').view(
            time, n_neurons).sum(0).long()
        if plot:
            _input = sample.view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Ae')].w
            square_assignments = get_square_assignments(assignments, n_sqrt)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            if i % update_interval == 0:  # plot weights on every update interval
                square_weights = get_square_weights(
                    input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                weights_im = plot_weights(square_weights,
                                          im=weights_im,
                                          save=save_weights_fn)
            inpt_axes, inpt_ims = plot_input(_input,
                                             reconstruction,
                                             label=label,
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            assigns_im = plot_assignments(square_assignments,
                                          im=assigns_im,
                                          save=save_assignments_fn)
            perf_ax = plot_performance(curves,
                                       ax=perf_ax,
                                       save=save_performance_fn)
            plt.pause(1e-8)
        current_labels[i % update_interval] = label[0]
        network.reset_state_variables()
        if i % 10 == 0 and i > 0:
            preds = ngram(
                spike_record[i % update_interval - 10:i % update_interval],
                ngram_scores, n_classes, 2)
            print(f'Predictions: {(preds*1.0).numpy()}')
            print(
                f'True value:  {current_labels[i%update_interval-10:i%update_interval].numpy()}'
            )
        time1 = t.time()
        i += 1
    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(targets, predictions[scheme])

    torch.save(confusions, os.path.join('.', 'confusion_test.pt'))
    print("Test completed. Testing took " + str((t.time() - test_time) / 6) +
          " min.")
Example #7

def main(seed=0, n_train=60000, n_test=10000, inhib=250, kernel_size=(16,), stride=(2,), n_filters=25, crop=4, lr=0.01,
         lr_decay=1, time=100, dt=1, theta_plus=0.05, theta_decay=1e-7, intensity=1, norm=0.2, progress_interval=10,
         update_interval=250, plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, inhib, time, dt,
        theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, n_test, inhib, time, dt,
            theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_examples = n_train if train else n_test

    network = load_network(os.path.join(params_path, model_name + '.pt'))

    network.layers['X'] = Input(n=400)
    network.layers['Y'] = DiehlAndCookNodes(
        n=network.layers['Y'].n, thresh=network.layers['Y'].thresh, rest=network.layers['Y'].rest,
        reset=network.layers['Y'].reset, theta_plus=network.layers['Y'].theta_plus,
        theta_decay=network.layers['Y'].theta_decay
    )

    network.add_layer(network.layers['X'], name='X')
    network.add_layer(network.layers['Y'], name='Y')

    network.connections['X', 'Y'].source = network.layers['X']
    network.connections['X', 'Y'].target = network.layers['Y']
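    # Freeze learning: swap in a no-op update rule and disable the adaptive
    # threshold (theta) dynamics.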

    network.connections['X', 'Y'].update_rule = NoOp(
        connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
    )
    network.layers['Y'].theta_decay = 0
    network.layers['Y'].theta_plus = 0

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod
    n_classes = 10

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    images, labels = dataset.get_test()
    images *= intensity
    images = images[:, crop:-crop, crop:-crop]

    # Neuron assignments and spike proportions.
    path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
    assignments, proportions, rates, ngram_scores = torch.load(open(path, 'rb'))

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    print('\nBegin black box adversarial attack.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    inpt_ims = None
    inpt_axes = None

    max_iters = 25
    delta = 0.1
    epsilon = 0.1

    for i in range(n_examples):
        # Get next input sample.
        original = images[i % len(images)].contiguous().view(-1)
        label = labels[i % len(images)]

        # Check if the image is correctly classified.
        sample = poisson(datum=original, time=time)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Check for incorrect classification.
        s = spikes['Y'].get('s').view(1, n_neurons, time)
        prediction = ngram(spikes=s, ngram_scores=ngram_scores, n_labels=10, n=2).item()

        if prediction != label:
            continue

        # Create adversarial example.
        adversarial = False
        while not adversarial:
            adv_example = 255 * torch.rand(original.size())
            sample = poisson(datum=adv_example, time=time)
            inpts = {'X': sample}

            # Run the network on the input.
            network.run(inpts=inpts, time=time)

            # Check for incorrect classification.
            s = spikes['Y'].get('s').view(1, n_neurons, time)
            prediction = ngram(spikes=s, ngram_scores=ngram_scores, n_labels=n_classes, n=2).item()

            # A valid starting point must already be misclassified.
            if prediction != label:
                adversarial = True

        j = 0
        current = original.clone()
        while j < max_iters:
            # Orthogonal perturbation.
            # perturb = orthogonal_perturbation(delta=delta, image=adv_example, target=original)
            # temp = adv_example + perturb

            # # Forward perturbation.
            # temp = temp.clone() + forward_perturbation(epsilon * get_diff(temp, original), temp, adv_example)

            # print(temp)

            perturbation = torch.randn(original.size())

            unnormed_source_direction = original - perturbation
            source_norm = torch.norm(unnormed_source_direction)
            source_direction = unnormed_source_direction / source_norm

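            # Remove the component of the noise along the source direction,
            # leaving an orthogonal step, rescaled to epsilon times the
            # source-direction norm.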
            dot = torch.dot(perturbation, source_direction)
            perturbation -= dot * source_direction
            perturbation *= epsilon * source_norm / torch.norm(perturbation)

            D = 1 / np.sqrt(epsilon ** 2 + 1)
            direction = perturbation - unnormed_source_direction
            spherical_candidate = current + D * direction

            spherical_candidate = torch.clamp(spherical_candidate, 0, 255)

            new_source_direction = original - spherical_candidate
            new_source_direction_norm = torch.norm(new_source_direction)

            # Step length if spherical_candidate were exactly on the sphere.
            length = delta * source_norm

            # Correct for the candidate's deviation from the sphere.
            deviation = new_source_direction_norm - source_norm
            length += deviation

            # Make sure the step size is positive, then normalize.
            length = max(0, length)
            length = length / new_source_direction_norm

            candidate = spherical_candidate + length * new_source_direction
            candidate = torch.clamp(candidate, 0, 255)

            sample = poisson(datum=candidate, time=time)
            inpts = {'X': sample}

            # Run the network on the input.
            network.run(inpts=inpts, time=time)

            # Check for incorrect classification.
            s = spikes['Y'].get('s').view(1, n_neurons, time)
            prediction = ngram(spikes=s, ngram_scores=ngram_scores, n_labels=10, n=2).item()

            # Optionally plot various simulation information.
            if plot:
                _input = original.view(side_length, side_length)
                reconstruction = candidate.view(side_length, side_length)
                _spikes = {
                    'X': spikes['X'].get('s').view(side_length ** 2, time),
                    'Y': spikes['Y'].get('s').view(n_neurons, time)
                }
                w = network.connections['X', 'Y'].w

                spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
                weights_im = plot_locally_connected_weights(
                    w, n_filters, kernel_size, conv_size, locations, side_length, im=weights_im
                )
                inpt_axes, inpt_ims = plot_input(
                    _input, reconstruction, label=labels[i], ims=inpt_ims, axes=inpt_axes
                )

                plt.pause(1e-8)

            if prediction == label:
                print('Attack failed.')
            else:
                print('Attack succeeded.')
                adv_example = candidate

            j += 1

        network.reset_()  # Reset state variables.

    print('\nAdversarial attack complete.\n')

if train:
    print('\nBegin training.\n')
else:
    print('\nBegin test.\n')

start = t()
for i in range(n_examples):
    if i % progress_interval == 0:
        print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
        start = t()

    if i % update_interval == 0 and i > 0:
        # Get network predictions.
        all_activity_pred = all_activity(spike_record, assignments, 10)
        proportion_pred = proportion_weighting(spike_record, assignments,
                                               proportions, 10)
        ngram_pred = ngram(spike_record, ngram_scores, 10, 2)

        # Compute network accuracy according to available classification strategies.
        curves['all'].append(100 * torch.sum(
            labels[i - update_interval:i].long() == all_activity_pred) / update_interval)
        curves['proportion'].append(100 * torch.sum(
            labels[i - update_interval:i].long() == proportion_pred) / update_interval)
        curves['ngram'].append(100 * torch.sum(
            labels[i - update_interval:i].long() == ngram_pred) / update_interval)

        print_results(curves)

        if train:
            if any([x[-1] > best_accuracy for x in curves.values()]):
                print('New best accuracy! Saving network parameters to disk.')