def nearest_mean_classifier(X_train,
                            X_test,
                            y_train,
                            error=False,
                            y_test=None):
    """
    Fit one centroid (class mean) per label on the training data and classify each
    test row by its nearest centroid.

    :param X_train: ndarray (n, d) -- Design matrix
    :param X_test: ndarray (m, d) -- Test matrix
    :param y_train: ndarray (n, ) -- Outcome vector of training data
    :param error: bool -- if True, return ``calculate_error(predictions, y_test)``
        instead of ``(predictions, means)``
    :param y_test: ndarray (m, ) -- true labels for X_test; required when ``error`` is True
    :return: tuple --
    predictions: ndarray (m, ) -- Vector of predicted labels (values taken from y_train)
    means: ndarray (labels.shape[0], X_train.shape[1]) -- Row i is the mean of class
        ``np.unique(y_train)[i]``
    :raises ValueError: if ``error`` is True but ``y_test`` is None
    """
    labels = np.unique(y_train)
    means = np.empty(shape=(labels.shape[0], X_train.shape[1]))
    # Row i holds the centroid of labels[i]. Indexing by the label value itself
    # breaks (or misplaces rows) when labels are not exactly 0..k-1.
    for i, label in enumerate(labels):
        means[i, :] = np.mean(X_train[y_train == label], axis=0)

    dist = np.empty(shape=(X_test.shape[0], labels.shape[0]))
    for i in range(labels.shape[0]):
        dist[:, i] = euclidean_norm(X_test, means[i, :]).squeeze()

    # argmin gives a column index; map it back to the actual class label.
    predictions = labels[np.argmin(dist, axis=-1)]

    if error:
        if y_test is None:
            raise ValueError("y_test is required when error=True")
        return calculate_error(predictions, y_test)
    return predictions, means
Пример #2
0
    def train(self, *, train_loader, test_loader, num_epochs, device, **kwargs):
        """Run plain SGD for ``num_epochs`` epochs and collect per-epoch metrics.

        After every epoch the example-weighted average training loss, the
        full-dataset gradient norm (``utils.calculate_full_gradient_norm``) and
        the test error (``utils.calculate_error``) are recorded and printed.

        Returns a list with one dict per epoch, keyed by ``epoch``,
        ``train_loss``, ``grad_norm`` and ``test_error``.
        """
        # Serialise the settings while ``device`` is still a plain string.
        settings = {'num_epochs': num_epochs, 'device': device}
        print('SGDTrainer Hyperparameters:', json.dumps(settings, indent=2))
        print('Unused kwargs:', kwargs)

        torch_device = torch.device(device)
        history = []
        for epoch in range(1, num_epochs + 1):
            loss_sum = 0
            for batch in train_loader:
                inputs, targets = (t.to(torch_device) for t in batch)
                self.optimizer.zero_grad()
                batch_loss = self.loss_fn(self.model(inputs), targets)
                batch_loss.backward()
                self.optimizer.step()
                # Scale by batch size so the epoch figure is a per-example mean.
                loss_sum += batch_loss.item() * inputs.shape[0]

            epoch_metrics = {
                'epoch': epoch,
                'train_loss': loss_sum / len(train_loader.dataset),
                'grad_norm': utils.calculate_full_gradient_norm(
                    model=self.model, data_loader=train_loader,
                    loss_fn=self.loss_fn, device=torch_device),
                'test_error': utils.calculate_error(
                    model=self.model, data_loader=test_loader,
                    device=torch_device),
            }
            history.append(epoch_metrics)
            print('[Epoch {}] train_loss: {:.04f}, grad_norm: {:.02f}, test_error: {:.04f}'.format(
                epoch, epoch_metrics['train_loss'], epoch_metrics['grad_norm'],
                epoch_metrics['test_error']))
        return history
Пример #3
0
    def train(self, *, train_loader, test_loader, num_epochs, device, **kwargs):
        """Train ``self.model`` with ``self.optimizer`` for a fixed number of epochs.

        Inputs:
        train_loader: torch.utils.data.DataLoader data loader for the training dataset
        test_loader: torch.utils.data.DataLoader data loader for the test dataset
        num_epochs: the number of epochs to train for
        device: string denoting the device to run on. "cuda" or "cpu" are expected.
        kwargs: any additional keyword arguments are accepted but ignored (they are printed).

        Returns:
        metrics: a list of dictionaries, one per epoch, holding the epoch number,
            the average training loss, the full gradient norm and the test error.
        """
        print('SGDTrainer Hyperparameters:',
              json.dumps({'num_epochs': num_epochs, 'device': device}, indent=2))
        print('Unused kwargs:', kwargs)

        target_device = torch.device(device)
        epoch_template = '[Epoch {}] train_loss: {:.04f}, grad_norm: {:.02f}, test_error: {:.04f}'
        records = []
        for epoch in range(1, num_epochs + 1):
            weighted_loss = 0
            for batch in train_loader:
                x_batch, y_batch = (tensor.to(target_device) for tensor in batch)
                self.optimizer.zero_grad()
                output = self.model(x_batch)
                objective = self.loss_fn(output, y_batch)
                objective.backward()
                self.optimizer.step()
                # Weight by batch size so the epoch average is per example.
                weighted_loss += objective.item() * x_batch.shape[0]

            mean_loss = weighted_loss / len(train_loader.dataset)
            grad_norm = utils.calculate_full_gradient_norm(
                model=self.model,
                data_loader=train_loader,
                loss_fn=self.loss_fn,
                device=target_device,
            )
            error_rate = utils.calculate_error(
                model=self.model, data_loader=test_loader, device=target_device)
            records.append(dict(epoch=epoch, train_loss=mean_loss,
                                grad_norm=grad_norm, test_error=error_rate))
            print(epoch_template.format(epoch, mean_loss, grad_norm, error_rate))
        return records
Пример #4
0
    # Edge-detect the greyscale frame, close small gaps, then look at the
    # upper and lower halves of the image separately.
    grey = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    src = cv2.Canny(grey, lowThreshold, lowThreshold * ratio, apertureSize=kernel_size)
    src_processed = cv2.morphologyEx(src, cv2.MORPH_CLOSE, dilate_kernel, iterations=1)
    crop_img = []
    img_sum_av = []
    img_sum_bi = []

    # i = 1 -> upper half, i = 2 -> lower half.
    for i in range(1, 3):
        crop_y_start = int(((i-1) / 2) * size_y)
        crop_y_end = int((i / 2) * size_y)
        sliced_img = src_processed[crop_y_start:crop_y_end, 0:size_x]
        crop_img.append(sliced_img)
        # Column-wise edge count, smoothed with a Savitzky-Golay filter (window 101, order 3).
        img_sum_av.append(sig.savgol_filter(np.sum(sliced_img, axis=0), 101, 3))
    center_upper = calculate_center(img_sum_av[0])
    center_lower = calculate_center(img_sum_av[1])
    error_upper = calculate_error(img_sum_av[0], center_upper)
    error_lower = calculate_error(img_sum_av[1], center_lower)
    display_graph = False
    # Compare by value: `is not 0` tests object identity, which is wrong for
    # numpy integers/floats equal to zero (and is a SyntaxWarning on 3.8+).
    # Assumes calculate_center returns a scalar -- TODO confirm.
    if center_upper != 0 and center_lower != 0:
        if (error_upper > 0) and (error_lower > 0):
            display_graph = True
            threshold_list.append(lowThreshold)
            # Track the threshold giving the smallest error for each half.
            if min_error_upper > error_upper:
                min_error_upper = error_upper
                min_error_upper_threshold = lowThreshold
            if min_error_lower > error_lower:
                min_error_lower = error_lower
                min_error_lower_threshold = lowThreshold
    if display_graph:
        center_upper = calculate_center(img_sum_av[0])
        center_lower = calculate_center(img_sum_av[1])
def execute_gru2_twitter_sentiment():
    """Train a ``gru.GRU_2`` model on DJI close/volume plus daily Twitter sentiment.

    The model is trained on the 2017-11-01..2017-11-30 rows and evaluated on
    the 2018-02-02..2018-02-26 rows. The date slices are written newest-first,
    which presumably matches a descending date index -- TODO confirm. MAE, MAPE
    and RMSE are printed for both the training fit and the test run.

    :return: predictions returned by ``gru.test_model`` on the February data.
    """
    from django.db.models import Avg
    from dbmgr.models import DJI, USTwitterNewsFeed
    import pandas as pd
    from nn_models import gru
    from utils import normalize, calculate_error

    def _report(mae, mape, rmse):
        # Print the three error metrics returned by calculate_error.
        print('mae: ', mae)
        print('mape: ', mape)
        print('rmse: ', rmse)

    def _add_daily_sentiment(frame, start, end):
        # Average tweet sentiment per calendar day in [start, end] and store it
        # in frame['sentiment'], keyed by the day rendered as a string.
        daily = USTwitterNewsFeed.objects.filter(
            created_datetime__range=[start, end])
        daily = daily.extra(select={'day': 'date( created_datetime )'}). \
            values('day').annotate(avg_sentiment=Avg('sentiment')).order_by('day')
        for row in daily:
            frame.loc[str(row['day']), 'sentiment'] = row['avg_sentiment']

    # PREPARE PD-FRAME PRICE DATA
    dji = list(
        DJI.objects.values_list('date', 'open', 'high', 'low', 'close',
                                'volume'))
    result = np.asarray(dji)
    description = np.asarray(
        ['date', 'open', 'high', 'low', 'close', 'volume'])

    data = pd.DataFrame(data=result, columns=description[:])
    data.set_index('date', inplace=True)
    data.index = data.index.map(str)

    for col in data.columns:
        data[col] = data[col].astype(dtype=np.float64)

    # Column 3 (after 'date' becomes the index) is 'close'.
    price_change = data.iloc[1:, [3]].values - data.iloc[:-1, [3]].values
    p_change = price_change / data.iloc[:-1, [3]].values
    price_change = np.concatenate(([[0]], price_change))
    p_change = np.concatenate(([[0]], p_change))
    data['price_change'] = price_change
    data['p_change'] = p_change * 100
    data = data.drop(data.index[0])  # remove first row (it has no previous day)

    # CALCULATE NORMALIZE Y: close relative to the first remaining row
    train_y = data['close'].values / data['close'].values[0] - 1

    # SELECT FEATURES FOR THE MODEL
    norm_data = pd.DataFrame(index=data.index)
    norm_data['close'] = data['close']
    norm_data['volume'] = data['volume']

    # TRAIN_Y INDEXED BY DATE
    y_data = pd.DataFrame(index=data.index)
    y_data['train_y'] = train_y

    # INCORPORATE SENTIMENT AS FEATURES -- TEST DATA, then TRAINING DATA
    _add_daily_sentiment(norm_data, "2018-02-01", "2018-02-28")
    _add_daily_sentiment(norm_data, "2017-11-01", "2017-12-01")

    test_data = norm_data["2018-02-26":"2018-02-02"]
    test_data = np.asarray(test_data, dtype=np.float64)
    test_data_y = np.asarray(y_data["2018-02-26":"2018-02-02"]["train_y"])

    norm_data = norm_data["2017-11-30":"2017-11-01"]
    norm_data = np.asarray(norm_data, dtype=np.float64)
    norm_data_y = np.asarray(y_data["2017-11-30":"2017-11-01"]["train_y"])

    # NORMALIZE FEATURES
    test_data = normalize(test_data)
    norm_data = normalize(norm_data)

    # CALCULATE MODEL DIMENSIONS
    input_dim = norm_data.shape[1]
    hidden_dim = input_dim * 2

    # TRAIN MODEL AND EVALUATE TRAINING LOSS ON NOVEMBER
    model = gru.GRU_2(input_dim, hidden_dim, seed=0)
    preds = gru.train_model(model, norm_data, norm_data_y, seed=0)

    # Predictions are compared against targets from index 2 onward, presumably
    # because the model consumes a 2-step window -- TODO confirm.
    _report(*calculate_error(norm_data_y[2:], preds))

    # TEST MODEL ON UNSEEN TEST DATA ON FEBRUARY
    preds = gru.test_model(model, test_data, test_data_y)

    _report(*calculate_error(test_data_y[2:], preds))
    return preds
Пример #6
0
    def train(self, *, train_loader, test_loader, num_warmup_epochs,
              num_outer_epochs, num_inner_epochs, inner_epoch_fraction,
              warmup_learning_rate, learning_rate, device, weight_decay,
              choose_random_iterate, **kwargs):
        """Executes training for the model over the given dataset and hyperparameters.

        A freshly created target model is first warmed up with plain SGD. It is then
        optimised with SVRG: each outer epoch computes the full gradient ``mu`` at the
        target model, and SGD is run on the variance-reduced auxiliary loss for a copy
        of it.

        Inputs:
        train_loader: torch.utils.data.DataLoader data loader for the training dataset
        test_loader: torch.utils.data.DataLoader data loader for the test dataset
        num_warmup_epochs: number of epochs to run SGD before starting SVRG
        num_outer_epochs: number of outer SVRG iterations
        num_inner_epochs: number of inner epochs to run for each outer epoch of SVRG.
        inner_epoch_fraction: if the number of inner iterations is not an integer number of epochs, this parameter
            can be used to specify the fraction of batches to iterate over. Only supported for less than a single epoch.
        warmup_learning_rate: the learning rate to use for SGD during the warmup phase.
        learning_rate: the learning rate to use for SVRG.
        device: string denoting the device to run on. "cuda" or "cpu" are expected.
        weight_decay: L2 regularization hyperparameter, used for both the warmup and SVRG phases.
        choose_random_iterate: if True, a random inner iterate will be chosen for the weights to use for the next
            outer epoch. otherwise, it will use the last inner iterate.
        kwargs: any additional keyword arguments will be accepted but ignored.

        Returns:
        metrics: a list of dictionaries containing information about the run, including the training loss,
            gradient norm, and test error for each epoch.

        The logged "(1k) ex/s" throughput covers only the optimisation pass of each
        epoch, not the gradient-norm and test-error evaluation. It also counts only
        the examples actually processed.
        """
        print(
            'SVRGTrainer Hyperparameters:',
            json.dumps(
                {
                    'num_warmup_epochs': num_warmup_epochs,
                    'num_outer_epochs': num_outer_epochs,
                    'num_inner_epochs': num_inner_epochs,
                    'inner_epoch_fraction': inner_epoch_fraction,
                    'warmup_learning_rate': warmup_learning_rate,
                    'learning_rate': learning_rate,
                    'device': device,
                    'weight_decay': weight_decay,
                    'choose_random_iterate': choose_random_iterate
                },
                indent=2))
        print('Unused kwargs:', kwargs)

        device = torch.device(device)
        metrics = []

        model = self.create_model().to(device)
        target_model = self.create_model().to(device)
        print(model)

        def cpu_snapshot(module):
            # Copy every state tensor straight to host memory, so snapshots never
            # hold an extra on-device copy of the parameters.
            return {k: v.to('cpu', copy=True) for k, v in module.state_dict().items()}

        # Perform several epochs of SGD as initialization for SVRG
        warmup_optimizer = torch.optim.SGD(target_model.parameters(),
                                           lr=warmup_learning_rate,
                                           weight_decay=weight_decay)
        for warmup_epoch in range(1, num_warmup_epochs + 1):
            warmup_loss = 0
            epoch_start = time.time()
            for batch in train_loader:
                data, label = (x.to(device) for x in batch)
                warmup_optimizer.zero_grad()
                prediction = target_model(data)
                loss = self.loss_fn(prediction, label)
                loss.backward()
                warmup_optimizer.step()
                warmup_loss += loss.item() * len(data)
            # Time only the optimisation pass. The evaluation passes below make a
            # full extra sweep over the data and would distort the throughput figure.
            elapsed_time = time.time() - epoch_start
            ex_per_sec = len(train_loader.dataset) / elapsed_time
            avg_warmup_loss = warmup_loss / len(train_loader.dataset)
            model_grad_norm = utils.calculate_full_gradient_norm(
                model=target_model,
                data_loader=train_loader,
                loss_fn=self.loss_fn,
                device=device)
            test_error = utils.calculate_error(model=target_model,
                                               data_loader=test_loader,
                                               device=device)
            metrics.append({
                'warmup_epoch': warmup_epoch,
                'train_loss': avg_warmup_loss,
                'grad_norm': model_grad_norm,
                'test_error': test_error
            })
            print(
                '[Warmup {}/{}] loss: {:.04f}, grad_norm: {:.02f}, test_error: {:.04f}, (1k) ex/s: {:.02f}'
                .format(warmup_epoch, num_warmup_epochs, avg_warmup_loss,
                        model_grad_norm, test_error, ex_per_sec / 1000))

        for epoch in range(1, num_outer_epochs + 1):
            # Find full target gradient
            mu = utils.calculate_full_gradient(model=target_model,
                                               data_loader=train_loader,
                                               loss_fn=self.loss_fn,
                                               device=device)

            # Initialize model to target model
            model.load_state_dict(copy.deepcopy(target_model.state_dict()))

            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay)
            # Option II (random iterate) needs every inner iterate. Option I only
            # needs the final one, so it avoids snapshotting after every step.
            inner_iterates = []
            last_iterate = None
            inner_batches = len(train_loader)
            if inner_epoch_fraction is not None:
                inner_batches = int(len(train_loader) * inner_epoch_fraction)
            for sub_epoch in range(1, num_inner_epochs + 1):
                train_loss = 0
                examples_seen = 0
                epoch_start = time.time()
                for batch_idx, batch in enumerate(train_loader):
                    data, label = (x.to(device) for x in batch)
                    optimizer.zero_grad()

                    # Calculate target model gradient
                    target_model.zero_grad()
                    target_model_out = target_model(data)
                    target_model_loss = self.loss_fn(target_model_out, label)
                    target_model_loss.backward()
                    target_model_grad = torch.cat([
                        x.grad.view(-1) for x in target_model.parameters()
                    ]).detach()

                    # Calculate current model loss
                    model_weights = torch.cat(
                        [x.view(-1) for x in model.parameters()])
                    model_out = model(data)
                    model_loss = self.loss_fn(model_out, label)

                    # Use SGD on auxiliary loss function
                    # See the SVRG paper section 2 for details
                    aux_loss = model_loss - \
                        torch.dot((target_model_grad - mu).detach(),
                                  model_weights)
                    aux_loss.backward()
                    optimizer.step()

                    # Bookkeeping
                    train_loss += model_loss.item() * len(data)
                    examples_seen += len(data)
                    if choose_random_iterate:
                        inner_iterates.append(cpu_snapshot(model))

                    batch_num = batch_idx + 1
                    if batch_num >= inner_batches:
                        break
                elapsed_time = time.time() - epoch_start
                if not choose_random_iterate:
                    # Take the snapshot before the evaluation passes below, which may
                    # touch module buffers depending on utils' implementation.
                    last_iterate = cpu_snapshot(model)
                # Calculate metrics for logging. With inner_epoch_fraction < 1 only
                # examples_seen examples were processed, not the whole dataset.
                ex_per_sec = examples_seen / elapsed_time
                avg_train_loss = train_loss / examples_seen
                model_grad_norm = utils.calculate_full_gradient_norm(
                    model=model,
                    data_loader=train_loader,
                    loss_fn=self.loss_fn,
                    device=device)
                test_error = utils.calculate_error(model=model,
                                                   data_loader=test_loader,
                                                   device=device)
                metrics.append({
                    'outer_epoch': epoch,
                    'inner_epoch': sub_epoch,
                    'train_loss': avg_train_loss,
                    'grad_norm': model_grad_norm,
                    'test_error': test_error
                })
                print(
                    '[Outer {}/{}, Inner {}/{}] loss: {:.04f}, grad_norm: {:.02f}, test_error: {:.04f}, (1k) ex/s: {:.02f}'
                    .format(epoch, num_outer_epochs, sub_epoch,
                            num_inner_epochs, avg_train_loss, model_grad_norm,
                            test_error, ex_per_sec / 1000))  # noqa

            # This choice corresponds to options I and II from the SVRG paper. Depending on the hyperparameter it
            # will either choose the last inner iterate for the target model, or use a random iterate.
            # With no inner steps at all, the target model is left unchanged.
            if choose_random_iterate:
                new_target_state_dict = random.choice(inner_iterates) if inner_iterates else None
            else:
                new_target_state_dict = last_iterate
            if new_target_state_dict is not None:
                target_model.load_state_dict(new_target_state_dict)
        return metrics
    def run(self, learning_rate, eta):
        """Train the network online, one input set at a time, with momentum backpropagation.

        For each input set, the method:

        * forward-propagates the inputs;
        * computes the error against ``self.formal_outputs`` via ``utils.calculate_error``;
        * back-propagates sigmoid deltas layer by layer, updating every layer's
          outgoing weights;
        * prints a report with ``self.print_network()``.

        :param learning_rate: coefficient applied to each neuron's previous weight
            change, i.e. the momentum term in ``learning_rate * old_dw + eta * Y``.
        :param eta: step size applied to the current gradient term ``Y``.
        """
        for input_set_number in range(len(self.inputs)):
            # Forward pass: load the inputs, then push weighted sums layer by layer.
            for i in range(self.layers[0].size()):
                self.layers[0][i].input_value = self.inputs[input_set_number][
                    i]

            for layer in self.layers:
                if layer.get_output_to_layer() is not None:
                    for output_neuron_number in range(
                            layer.get_output_to_layer().size()):
                        output_value = 0

                        for input_neuron_number in range(layer.size()):
                            output_value += layer[input_neuron_number].get_activated_output() \
                                            * layer[input_neuron_number].weights[output_neuron_number]

                        layer.get_output_to_layer(
                        )[output_neuron_number].input_value = output_value

            output_layer = self.layers[-1]
            error = utils.calculate_error(
                self.formal_outputs[input_set_number],
                [neuron.get_activated_output() for neuron in output_layer])

            # Output-layer deltas for sigmoid units: (target - output) * o * (1 - o).
            for output_neuron_number in range(output_layer.size()):
                formal_output = self.formal_outputs[input_set_number][
                    output_neuron_number]
                actual_output = output_layer[
                    output_neuron_number].get_activated_output()

                output_layer[output_neuron_number].d_value = (
                    formal_output -
                    actual_output) * actual_output * (1 - actual_output)

            # Row vector of deltas for the layer currently being back-propagated into.
            d_vector = numpy.array(
                [[neuron.d_value for neuron in output_layer]])

            # Walk backwards. Each iteration updates the weights that feed the current
            # layer from previous_layer, then makes previous_layer's deltas (d*) the
            # new d_vector. Previously every earlier layer reused the output layer's
            # delta-w and d* was discarded.
            previous_layer = output_layer.input_from_layer
            while previous_layer is not None:
                # Matrix Y = O_prev^T . d -> shape (n_prev, n_current)
                output_of_previous_layer = numpy.array(
                    [[neuron.get_activated_output() for neuron in previous_layer]]).T
                matrix_y = numpy.matmul(output_of_previous_layer, d_vector)

                # Delta w with momentum on each neuron's previous change
                old_delta_w = numpy.array(
                    [neuron.last_change_in_weights for neuron in previous_layer])
                new_delta_w = (learning_rate * old_delta_w) + (eta * matrix_y)

                # Vector e = W . d^T, using the weights *before* this update.
                # The transpose makes the shapes agree for any number of current-layer units.
                matrix_w = numpy.array([neuron.weights for neuron in previous_layer])
                vector_e = numpy.matmul(matrix_w, d_vector.T)

                # d* for the previous layer (sigmoid derivative)
                d_star = []
                for neuron_number in range(previous_layer.size()):
                    OHi = previous_layer[neuron_number].get_activated_output()
                    d_star.append((vector_e[neuron_number][0] * OHi) * (1 - OHi))

                # Apply delta w and remember it on the neuron (not the layer) so the
                # momentum term sees it next time.
                for neuron_number in range(previous_layer.size()):
                    neuron = previous_layer[neuron_number]
                    neuron_delta = new_delta_w[neuron_number]
                    neuron.last_change_in_weights = neuron_delta
                    neuron.weights = [
                        old_weight + neuron_delta[weight_number]
                        for weight_number, old_weight in enumerate(neuron.weights)
                    ]

                d_vector = numpy.array([d_star])  # propagate back
                previous_layer = previous_layer.input_from_layer

            print(
                "------------------------ Iteration {} ------------------------"
                .format(input_set_number))
            print("Error : {:.2f}".format(error))
            self.print_network()
            print(
                "--------------------------------------------------------------\n\n"
            )
Пример #8
0
    def train(self, *, train_loader, test_loader, num_warmup_epochs,
              num_outer_epochs, num_inner_epochs, inner_epoch_fraction,
              warmup_learning_rate, learning_rate, device, weight_decay,
              choose_random_iterate, **kwargs):
        """Warm up with SGD, then train with SVRG, returning per-epoch metrics.

        Inputs:
        train_loader: torch.utils.data.DataLoader for the training dataset
        test_loader: torch.utils.data.DataLoader for the test dataset
        num_warmup_epochs: number of plain-SGD epochs run on the target model first
        num_outer_epochs: number of outer SVRG iterations (full-gradient recomputations)
        num_inner_epochs: number of inner epochs per outer epoch
        inner_epoch_fraction: if not None, each inner epoch stops after
            int(len(train_loader) * inner_epoch_fraction) batches (at least one batch)
        warmup_learning_rate: SGD learning rate during warmup
        learning_rate: SGD learning rate on the SVRG auxiliary loss
        device: device string, e.g. "cuda" or "cpu"
        weight_decay: L2 regularisation used in both phases
        choose_random_iterate: if True, the next target model is a random inner
            iterate (SVRG option II); otherwise the last inner iterate (option I)
        kwargs: accepted and printed, otherwise ignored

        Returns:
        metrics: list of dicts, one per warmup epoch and per inner epoch, with the
            training loss, full gradient norm and test error.

        The logged "(1k) ex/s" throughput covers only the optimisation pass of each
        epoch and counts only the examples actually processed.
        """
        print(
            'SVRGTrainer Hyperparameters:',
            json.dumps(
                {
                    'num_warmup_epochs': num_warmup_epochs,
                    'num_outer_epochs': num_outer_epochs,
                    'num_inner_epochs': num_inner_epochs,
                    'inner_epoch_fraction': inner_epoch_fraction,
                    'warmup_learning_rate': warmup_learning_rate,
                    'learning_rate': learning_rate,
                    'device': device,
                    'weight_decay': weight_decay,
                    'choose_random_iterate': choose_random_iterate
                },
                indent=2))
        print('Unused kwargs:', kwargs)

        device = torch.device(device)
        metrics = []

        model = self.create_model().to(device)
        target_model = self.create_model().to(device)
        print(model)

        def cpu_snapshot(module):
            # Copy every state tensor straight to host memory, so snapshots never
            # hold an extra on-device copy of the parameters.
            return {k: v.to('cpu', copy=True) for k, v in module.state_dict().items()}

        # Perform several epochs of SGD as initialization for SVRG
        warmup_optimizer = torch.optim.SGD(target_model.parameters(),
                                           lr=warmup_learning_rate,
                                           weight_decay=weight_decay)
        for warmup_epoch in range(1, num_warmup_epochs + 1):
            warmup_loss = 0
            epoch_start = time.time()
            for batch in train_loader:
                data, label = (x.to(device) for x in batch)
                warmup_optimizer.zero_grad()
                prediction = target_model(data)
                loss = self.loss_fn(prediction, label)
                loss.backward()
                warmup_optimizer.step()
                warmup_loss += loss.item() * len(data)
            # Time only the optimisation pass. The evaluation passes below make a
            # full extra sweep over the data and would distort the throughput figure.
            elapsed_time = time.time() - epoch_start
            ex_per_sec = len(train_loader.dataset) / elapsed_time
            avg_warmup_loss = warmup_loss / len(train_loader.dataset)
            model_grad_norm = utils.calculate_full_gradient_norm(
                model=target_model,
                data_loader=train_loader,
                loss_fn=self.loss_fn,
                device=device)
            test_error = utils.calculate_error(model=target_model,
                                               data_loader=test_loader,
                                               device=device)
            metrics.append({
                'warmup_epoch': warmup_epoch,
                'train_loss': avg_warmup_loss,
                'grad_norm': model_grad_norm,
                'test_error': test_error
            })
            print(
                '[Warmup {}/{}] loss: {:.04f}, grad_norm: {:.02f}, test_error: {:.04f}, (1k) ex/s: {:.02f}'
                .format(warmup_epoch, num_warmup_epochs, avg_warmup_loss,
                        model_grad_norm, test_error, ex_per_sec / 1000))

        for epoch in range(1, num_outer_epochs + 1):
            # Find full target gradient
            mu = utils.calculate_full_gradient(model=target_model,
                                               data_loader=train_loader,
                                               loss_fn=self.loss_fn,
                                               device=device)

            # Initialize model to target model
            model.load_state_dict(copy.deepcopy(target_model.state_dict()))

            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=learning_rate,
                                        weight_decay=weight_decay)
            # Option II (random iterate) needs every inner iterate. Option I only
            # needs the final one, so it avoids snapshotting after every step.
            inner_iterates = []
            last_iterate = None
            inner_batches = len(train_loader)
            if inner_epoch_fraction is not None:
                inner_batches = int(len(train_loader) * inner_epoch_fraction)
            for sub_epoch in range(1, num_inner_epochs + 1):
                train_loss = 0
                examples_seen = 0
                epoch_start = time.time()
                for batch_idx, batch in enumerate(train_loader):

                    data, label = (x.to(device) for x in batch)
                    optimizer.zero_grad()

                    # Gradient of the target (snapshot) model on this batch
                    target_model.zero_grad()
                    target_model_out = target_model(data)
                    target_model_loss = self.loss_fn(target_model_out, label)
                    target_model_loss.backward()
                    target_model_grad = torch.cat([
                        x.grad.view(-1) for x in target_model.parameters()
                    ]).detach()

                    # Loss of the current model on the same batch
                    model_weights = torch.cat(
                        [x.view(-1) for x in model.parameters()])
                    model_out = model(data)
                    model_loss = self.loss_fn(model_out, label)

                    # Use SGD on auxiliary loss function
                    # See the SVRG paper section 2 for details
                    aux_loss = model_loss - \
                        torch.dot((target_model_grad - mu).detach(),
                                  model_weights)
                    aux_loss.backward()
                    optimizer.step()

                    train_loss += model_loss.item() * len(data)
                    examples_seen += len(data)
                    if choose_random_iterate:
                        inner_iterates.append(cpu_snapshot(model))

                    batch_num = batch_idx + 1
                    if batch_num >= inner_batches:
                        break
                elapsed_time = time.time() - epoch_start
                if not choose_random_iterate:
                    # Take the snapshot before the evaluation passes below, which may
                    # touch module buffers depending on utils' implementation.
                    last_iterate = cpu_snapshot(model)
                # With inner_epoch_fraction < 1 only examples_seen examples were
                # processed, not the whole dataset.
                ex_per_sec = examples_seen / elapsed_time
                avg_train_loss = train_loss / examples_seen
                model_grad_norm = utils.calculate_full_gradient_norm(
                    model=model,
                    data_loader=train_loader,
                    loss_fn=self.loss_fn,
                    device=device)
                test_error = utils.calculate_error(model=model,
                                                   data_loader=test_loader,
                                                   device=device)
                metrics.append({
                    'outer_epoch': epoch,
                    'inner_epoch': sub_epoch,
                    'train_loss': avg_train_loss,
                    'grad_norm': model_grad_norm,
                    'test_error': test_error
                })
                print(
                    '[Outer {}/{}, Inner {}/{}] loss: {:.04f}, grad_norm: {:.02f}, test_error: {:.04f}, (1k) ex/s: {:.02f}'
                    .format(epoch, num_outer_epochs, sub_epoch,
                            num_inner_epochs, avg_train_loss, model_grad_norm,
                            test_error, ex_per_sec / 1000))  # noqa

            # SVRG options I / II: last iterate or a random inner iterate becomes
            # the next target. With no inner steps at all, the target is left unchanged.
            if choose_random_iterate:
                new_target_state_dict = random.choice(inner_iterates) if inner_iterates else None
            else:
                new_target_state_dict = last_iterate
            if new_target_state_dict is not None:
                target_model.load_state_dict(new_target_state_dict)
        return metrics