Example #1
class MnistOptimizee:
    def __init__(self, seed, n_ensembles, batch_size, root, path):

        self.random_state = np.random.RandomState(seed=seed)

        self.n_ensembles = n_ensembles
        self.batch_size = batch_size
        self.root = root
        self.path = path

        self.conv_net = ConvNet().to(device)
        self.criterion = nn.CrossEntropyLoss()
        self.data_loader = DataLoader()
        self.data_loader.init_iterators(self.root, self.path, self.batch_size)
        self.dataiter_mnist = self.data_loader.dataiter_mnist
        self.testiter_mnist = self.data_loader.testiter_mnist
        self.dataiter_notmnist = self.data_loader.dataiter_notmnist
        self.testiter_notmnist = self.data_loader.testiter_notmnist

        self.generation = 0
        self.inputs, self.labels = self.dataiter_mnist()
        self.test_input, self.test_label = self.testiter_mnist()

        # For plotting
        self.train_pred = []
        self.test_pred = []
        self.train_acc = []
        self.test_acc = []
        self.train_cost = []
        self.test_cost = []
        self.targets = []
        self.output_activity_train = []
        self.output_activity_test = []

        self.targets.append(self.labels.numpy())

    def create_individual(self):
        # Flatten all network weights and biases into one vector per
        # ensemble member (layer order follows the state_dict keys:
        # conv1.weight, conv1.bias, conv2.weight, conv2.bias,
        # fc1.weight, fc1.bias).
        conv_ensembles = []
        with torch.no_grad():
            # total number of parameters across all layers
            length = 0
            for key in self.conv_net.state_dict().keys():
                length += self.conv_net.state_dict()[key].nelement()

            # initialize each ensemble member uniformly in [-1, 1]
            for _ in range(self.n_ensembles):
                conv_ensembles.append(np.random.uniform(-1, 1, size=length))

            # convert targets to numpy arrays
            targets = tuple(label.numpy() for label in self.labels)
            return dict(conv_params=np.array(conv_ensembles),
                        targets=targets,
                        input=self.inputs.squeeze().numpy())

    @staticmethod
    def _he_init(weights, gain=0):
        """
        He- or Kaiming- initialization as in He et al., "Delving deep into
        rectifiers: Surpassing human-level performance on ImageNet
        classification". Values are sampled from
        :math:`\\mathcal{N}(0, \\text{std})` where

        .. math::
        \text{std} = \\sqrt{\\frac{2}{(1 + a^2) \\times \text{fan\\_in}}}

        Note: Only for the case that the non-linearity of the network
            activation is `relu`

        :param weights, tensor
        :param gain, additional scaling factor, Default is 0
        :return: numpy nd array, random array of size `weights`
        """
        fan_in = torch.nn.init._calculate_correct_fan(weights, 'fan_in')
        stddev = np.sqrt(2. / fan_in * (1 + gain**2))
        return np.random.normal(0, stddev, weights.numel())

    def set_parameters(self, ensembles):
        # set the new parameters for the network
        conv_params = np.mean(ensembles, axis=0)
        ds = self._shape_parameter_to_conv_net(conv_params)
        self.conv_net.set_parameter(ds)
        print('---- Train -----')
        print('Generation ', self.generation)
        generation_change = 8
        dataset_change = 500
        with torch.no_grad():
            inputs = self.inputs
            labels = self.labels

            if self.generation % generation_change == 0 and self.generation < dataset_change:
                self.inputs, self.labels = self.dataiter_mnist()
                print('New MNIST set used at generation {}'.format(
                    self.generation))
                # append the outputs
                self.targets.append(self.labels.numpy())
            elif self.generation >= dataset_change and self.generation % generation_change == 0:
                if self.generation == dataset_change:
                    # switch the test set to notMNIST once, at the changeover
                    self.test_input, self.test_label = self.testiter_notmnist()
                self.inputs, self.labels = self.dataiter_notmnist()
                print('New notMNIST set used at generation {}'.format(
                    self.generation))
                # append the outputs
                self.targets.append(self.labels.numpy())

            outputs = self.conv_net(inputs)
            self.output_activity_train.append(outputs.numpy())
            conv_loss = self.criterion(outputs, labels).item()
            train_cost = _calculate_cost(_encode_targets(labels, 10),
                                         outputs.numpy(), 'MSE')
            train_acc = score(labels.numpy(), np.argmax(outputs.numpy(), 1))
            print('Cost: ', train_cost)
            print('Accuracy: ', train_acc)
            print('Loss:', conv_loss)
            self.train_cost.append(train_cost)
            self.train_acc.append(train_acc)
            self.train_pred.append(np.argmax(outputs.numpy(), 1))

            print('---- Test -----')
            test_output = self.conv_net(self.test_input)
            test_output = test_output.numpy()
            test_acc = score(self.test_label.numpy(),
                             np.argmax(test_output, 1))
            test_cost = _calculate_cost(_encode_targets(self.test_label, 10),
                                        test_output, 'MSE')
            print('Test accuracy', test_acc)
            self.test_acc.append(test_acc)
            self.test_pred.append(np.argmax(test_output, 1))
            self.test_cost.append(test_cost)
            self.output_activity_test.append(test_output)
            print('-----------------')
            conv_params = []
            for c in ensembles:
                ds = self._shape_parameter_to_conv_net(c)
                self.conv_net.set_parameter(ds)
                conv_params.append(self.conv_net(inputs).numpy().T)

            outs = {
                'conv_params': np.array(conv_params),
                'conv_loss': float(conv_loss),
                'input': self.inputs.squeeze().numpy(),
                'targets': self.labels.numpy()
            }
        return outs

    def _shape_parameter_to_conv_net(self, params):
        param_dict = dict()
        start = 0
        for key in self.conv_net.state_dict().keys():
            shape = self.conv_net.state_dict()[key].shape
            length = self.conv_net.state_dict()[key].nelement()
            end = start + length
            param_dict[key] = params[start:end].reshape(shape)
            start = end
        return param_dict
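
All of these examples lean on a few helpers that are not shown: score, _encode_targets and _calculate_cost. A minimal sketch of plausible implementations, assuming score is plain classification accuracy, _encode_targets is one-hot encoding, and the 'MSE' method is a mean squared error over the one-hot targets (numpy arrays or CPU tensors assumed):

import numpy as np


def score(y_true, y_pred):
    # fraction of correctly classified samples (plain accuracy)
    return np.mean(np.asarray(y_true) == np.asarray(y_pred))


def _encode_targets(targets, n_classes):
    # one-hot encode integer class labels into an (N, n_classes) array
    targets = np.asarray(targets).astype(int)
    encoded = np.zeros((len(targets), n_classes))
    encoded[np.arange(len(targets)), targets] = 1.0
    return encoded


def _calculate_cost(targets, outputs, method='MSE'):
    # mean squared error between one-hot targets and network outputs
    if method == 'MSE':
        return np.mean((np.asarray(targets) - np.asarray(outputs)) ** 2)
    raise ValueError('Unknown cost method: {}'.format(method))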
Example #2
            loss = criterion(output, target)
            test_loss += loss.item()
            # network prediction
            pred = output.argmax(1, keepdim=True)
            # count how many images are correctly classified, compared with the targets
            ta = pred.eq(target.view_as(pred)).sum().item()
            test_accuracy += ta
            test_acc = score(target.cpu().numpy(),
                             np.argmax(output.cpu().numpy(), 1))
            if idx % 10 == 0:
                print('Test Loss {}, idx {}'.format(
                    loss.item(), idx))

        print('Test accuracy: {} Average test loss: {}'.format(
            100 * test_accuracy / len(test_loader_mnist.dataset),
            test_loss / len(test_loader_mnist.dataset)))


if __name__ == '__main__':
    # pick the device first so the network can be moved onto it
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    conv_params = torch.load('conv_params.pt', map_location='cpu')
    ensemble = conv_params['ensemble'].mean(0)
    ensemble = torch.from_numpy(ensemble)
    conv_net = ConvNet().to(device)
    ds = shape_parameter_to_conv_net(conv_net, ensemble)
    conv_net.set_parameter(ds)
    criterion = torch.nn.CrossEntropyLoss()
    batch = 64
    train_loader, test_loader = get_data(batch, device)
    test(conv_net, test_loader)
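
The __main__ block above also assumes a module-level shape_parameter_to_conv_net helper. A minimal sketch, mirroring the _shape_parameter_to_conv_net method that the optimizee classes in these examples define:

def shape_parameter_to_conv_net(net, params):
    # slice a flat parameter vector into per-layer tensors, using the
    # shapes recorded in the network's state_dict
    param_dict = dict()
    start = 0
    for key in net.state_dict().keys():
        shape = net.state_dict()[key].shape
        length = net.state_dict()[key].nelement()
        param_dict[key] = params[start:start + length].reshape(shape)
        start += length
    return param_dict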
Example #3
class MnistOptimizee(torch.nn.Module):
    def __init__(self, n_ensembles, batch_size, root):
        super(MnistOptimizee, self).__init__()

        self.n_ensembles = n_ensembles
        self.batch_size = batch_size
        self.root = root
        self.length = 0

        self.conv_net = ConvNet().to(device)
        self.criterion = nn.CrossEntropyLoss()
        self.data_loader = DataLoader()
        self.data_loader.init_iterators(self.root, self.batch_size)
        self.dataiter_mnist = self.data_loader.dataiter_mnist
        self.testiter_mnist = self.data_loader.testiter_mnist

        self.generation = 0  # iterations
        self.inputs, self.labels = self.dataiter_mnist()
        self.inputs = self.inputs.to(device)
        self.labels = self.labels.to(device)
        self.test_input, self.test_label = self.testiter_mnist()
        self.test_input = self.test_input.to(device)
        self.test_label = self.test_label.to(device)

        # For plotting
        self.train_pred = []
        self.test_pred = []
        self.train_acc = []
        self.test_acc = []
        self.train_cost = []
        self.test_cost = []
        self.train_loss = []
        self.test_loss = []
        self.targets = []
        self.output_activity_train = []
        self.output_activity_test = []
        self.act_func = {
            'act1': [],
            'act2': [],
            'act1_mean': [],
            'act2_mean': [],
            'act1_std': [],
            'act2_std': [],
            'act3': [],
            'act3_mean': [],
            'act3_std': []
        }
        self.test_act_func = {
            'act1': [],
            'act2': [],
            'act1_mean': [],
            'act2_mean': [],
            'act1_std': [],
            'act2_std': [],
            'act3': [],
            'act3_mean': [],
            'act3_std': []
        }

        self.targets.append(self.labels.cpu().numpy())

        for key in self.conv_net.state_dict().keys():
            self.length += self.conv_net.state_dict()[key].nelement()

    def create_individual(self):
        # get weights, biases from networks and flatten them
        # convolutional network parameters
        conv_ensembles = []
        sigma = config['sigma']
        with torch.no_grad():
            for _ in range(self.n_ensembles):
                conv_ensembles.append(
                    np.random.normal(0, sigma, size=self.length))
            return dict(conv_params=torch.as_tensor(np.array(conv_ensembles),
                                                    device=device),
                        targets=self.labels,
                        input=self.inputs.squeeze())

    def load_model(self, path='conv_params.npy'):
        print('Loading model from path: {}'.format(path))
        conv_params = np.load(path, allow_pickle=True).item()
        conv_ensembles = conv_params.get('ensemble')
        return dict(conv_params=torch.as_tensor(np.array(conv_ensembles),
                                                device=device),
                    targets=self.labels,
                    input=self.inputs.squeeze())

    def set_parameters(self, ensembles):
        # set the new parameters for the network
        conv_params = ensembles.mean(0)
        ds = self._shape_parameter_to_conv_net(conv_params)
        self.conv_net.set_parameter(ds)
        print('---- Train -----')
        print('Iteration ', self.generation)
        generation_change = config['repetitions']
        with torch.no_grad():
            inputs = self.inputs.to(device)
            labels = self.labels.to(device)

            if self.generation % generation_change == 0:
                self.inputs, self.labels = self.dataiter_mnist()
                self.inputs = self.inputs.to(device)
                self.labels = self.labels.to(device)
                print('New MNIST set used at generation {}'.format(
                    self.generation))
                # append the outputs
                self.targets.append(self.labels.cpu().numpy())
            # get network predictions
            outputs, act1, act2 = self.conv_net(inputs)
            act3 = outputs
            # save all important calculations
            self.act_func['act1'] = act1.cpu().numpy()
            self.act_func['act2'] = act2.cpu().numpy()
            self.act_func['act3'] = act3.cpu().numpy()
            self.act_func['act1_mean'].append(act1.mean().item())
            self.act_func['act2_mean'].append(act2.mean().item())
            self.act_func['act3_mean'].append(act3.mean().item())
            self.act_func['act1_std'].append(act1.std().item())
            self.act_func['act2_std'].append(act2.std().item())
            self.act_func['act3_std'].append(act3.std().item())
            self.output_activity_train.append(
                F.softmax(outputs, dim=1).cpu().numpy())
            conv_loss = self.criterion(outputs, labels).item()
            self.train_loss.append(conv_loss)
            train_cost = _calculate_cost(
                _encode_targets(labels, 10),
                F.softmax(outputs, dim=1).cpu().numpy(), 'MSE')
            train_acc = score(
                labels.cpu().numpy(),
                np.argmax(F.softmax(outputs, dim=1).cpu().numpy(), 1))
            print('Cost: ', train_cost)
            print('Accuracy: ', train_acc)
            print('Loss:', conv_loss)
            self.train_cost.append(train_cost)
            self.train_acc.append(train_acc)
            self.train_pred.append(
                np.argmax(F.softmax(outputs, dim=1).cpu().numpy(), 1))

            print('---- Test -----')
            test_output, act1, act2 = self.conv_net(self.test_input)
            test_loss = self.criterion(test_output, self.test_label).item()
            self.test_act_func['act1'] = act1.cpu().numpy()
            self.test_act_func['act2'] = act2.cpu().numpy()
            self.test_act_func['act1_mean'].append(act1.mean().item())
            self.test_act_func['act2_mean'].append(act2.mean().item())
            self.test_act_func['act3_mean'].append(test_output.mean().item())
            self.test_act_func['act1_std'].append(act1.std().item())
            self.test_act_func['act2_std'].append(act2.std().item())
            self.test_act_func['act3_std'].append(test_output.std().item())
            test_output = test_output.cpu().numpy()
            self.test_act_func['act3'] = test_output
            test_acc = score(self.test_label.cpu().numpy(),
                             np.argmax(test_output, 1))
            test_cost = _calculate_cost(
                _encode_targets(self.test_label.cpu().numpy(), 10),
                test_output, 'MSE')
            print('Test accuracy', test_acc)
            print('Test loss: {}'.format(test_loss))
            self.test_acc.append(test_acc)
            self.test_pred.append(np.argmax(test_output, 1))
            self.test_cost.append(test_cost)
            self.output_activity_test.append(test_output)
            self.test_loss.append(test_loss)
            print('-----------------')
            conv_params = []
            for c in ensembles:
                ds = self._shape_parameter_to_conv_net(c)
                self.conv_net.set_parameter(ds)
                params, _, _ = self.conv_net(inputs)
                conv_params.append(params.t().cpu().numpy())

            outs = {
                'conv_params': torch.as_tensor(np.array(conv_params),
                                               device=device),
                'conv_loss': float(conv_loss),
                'input': self.inputs.squeeze(),
                'targets': self.labels
            }
        return outs

    def _shape_parameter_to_conv_net(self, params):
        """
        Create a dictionary which includes the shapes of the network
        architecture per layer
        """
        param_dict = dict()
        start = 0
        for key in self.conv_net.state_dict().keys():
            shape = self.conv_net.state_dict()[key].shape
            length = self.conv_net.state_dict()[key].nelement()
            end = start + length
            param_dict[key] = params[start:end].reshape(shape)
            start = end
        return param_dict
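
The optimizee exposes a two-call interface: create_individual draws an initial ensemble of flattened network parameters, and set_parameters evaluates an updated ensemble, returning each member's predictions for the optimizer. A hypothetical driver loop showing how the two calls interlock; update_ensemble is a placeholder invented here for illustration (the real update rule, e.g. an ensemble Kalman filter step, is project-specific):

import torch


def update_ensemble(ensembles, predictions, targets):
    # placeholder update: pull members toward the ensemble mean and add
    # fresh noise; a real optimizer would weight this step using
    # `predictions` against `targets`
    mean = ensembles.mean(0, keepdim=True)
    return 0.5 * (ensembles + mean) + 0.01 * torch.randn_like(ensembles)


model = MnistOptimizee(n_ensembles=100, batch_size=64, root='./data')
ensembles = model.create_individual()['conv_params']
for generation in range(100):
    model.generation = generation
    outs = model.set_parameters(ensembles)
    ensembles = update_ensemble(ensembles, outs['conv_params'],
                                outs['targets'])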
Example #4
class MnistOptimizee(torch.nn.Module):
    def __init__(self, seed, n_ensembles, batch_size, root):
        super(MnistOptimizee, self).__init__()

        self.random_state = np.random.RandomState(seed=seed)

        self.n_ensembles = n_ensembles
        self.batch_size = batch_size
        self.root = root

        self.conv_net = ConvNet().to(device)
        self.criterion = nn.CrossEntropyLoss()
        self.data_loader = DataLoader()
        self.data_loader.init_iterators(self.root, self.batch_size)
        self.dataiter_mnist = self.data_loader.dataiter_mnist
        self.testiter_mnist = self.data_loader.testiter_mnist

        self.generation = 0
        self.inputs, self.labels = self.dataiter_mnist()
        self.inputs = self.inputs.to(device)
        self.labels = self.labels.to(device)
        self.test_input, self.test_label = self.testiter_mnist()
        self.test_input = self.test_input.to(device)
        self.test_label = self.test_label.to(device)

        self.timings = {
            'shape_parameters': [],
            'set_parameters': [],
            'set_parameters_cnn': [],
            'shape_parameters_ens': []
        }

        self.length = 0
        for key in self.conv_net.state_dict().keys():
            self.length += self.conv_net.state_dict()[key].nelement()

    def create_individual(self):
        # Flatten all network weights and biases into one vector per
        # ensemble member (layer order follows the state_dict keys:
        # conv1.weight, conv1.bias, conv2.weight, conv2.bias,
        # fc1.weight, fc1.bias).
        conv_ensembles = []
        with torch.no_grad():
            # initialize each ensemble member from N(0, 0.1)
            for _ in range(self.n_ensembles):
                conv_ensembles.append(
                    np.random.normal(0, 0.1, size=self.length))
            return dict(conv_params=torch.as_tensor(np.array(conv_ensembles),
                                                    device=device),
                        targets=self.labels,
                        input=self.inputs.squeeze())

    def load_model(self, path='conv_params.npy'):
        print('Loading model from path: {}'.format(path))
        conv_params = np.load(path, allow_pickle=True).item()
        conv_ensembles = conv_params.get('ensemble')
        return dict(conv_params=torch.as_tensor(conv_ensembles, device=device),
                    targets=self.labels,
                    input=self.inputs.squeeze())

    @staticmethod
    def _he_init(weights, gain=0):
        """
        He- or Kaiming- initialization as in He et al., "Delving deep into
        rectifiers: Surpassing human-level performance on ImageNet
        classification". Values are sampled from
        :math:`\\mathcal{N}(0, \\text{std})` where

        .. math::
        \text{std} = \\sqrt{\\frac{2}{(1 + a^2) \\times \text{fan\\_in}}}

        Note: Only for the case that the non-linearity of the network
            activation is `relu`

        :param weights, tensor
        :param gain, additional scaling factor, Default is 0
        :return: numpy nd array, random array of size `weights`
        """
        fan_in = torch.nn.init._calculate_correct_fan(weights, 'fan_in')
        stddev = np.sqrt(2. / fan_in * (1 + gain**2))
        return np.random.normal(0, stddev, weights.numel())

    def set_parameters(self, ensembles):
        # set the new parameters for the network
        conv_params = ensembles.mean(0)
        t = time.time()
        ds = self._shape_parameter_to_conv_net(conv_params)
        self.timings['shape_parameters_ens'].append(time.time() - t)
        t = time.time()
        self.conv_net.set_parameter(ds)
        self.timings['set_parameters_cnn'].append(time.time() - t)
        print('---- Train -----')
        print('Generation ', self.generation)
        generation_change = 1
        with torch.no_grad():
            inputs = self.inputs.to(device)
            labels = self.labels.to(device)

            if self.generation % generation_change == 0:
                self.inputs, self.labels = self.dataiter_mnist()
                self.inputs = self.inputs.to(device)
                self.labels = self.labels.to(device)
                print('New MNIST set used at generation {}'.format(
                    self.generation))

            outputs, act1, act2 = self.conv_net(inputs)
            conv_loss = self.criterion(outputs, labels).item()
            train_cost = _calculate_cost(_encode_targets(labels, 10),
                                         F.softmax(outputs, dim=1).cpu().numpy(),
                                         'MSE')
            train_acc = score(
                labels.cpu().numpy(),
                np.argmax(F.softmax(outputs, dim=1).cpu().numpy(), 1))
            print('Cost: ', train_cost)
            print('Accuracy: ', train_acc)
            print('Loss:', conv_loss)

            print('---- Test -----')
            test_output, act1, act2 = self.conv_net(self.test_input)
            test_loss = self.criterion(test_output, self.test_label).item()
            test_acc = score(self.test_label.cpu().numpy(),
                             np.argmax(test_output.cpu().numpy(), 1))
            test_cost = _calculate_cost(
                _encode_targets(self.test_label.cpu().numpy(), 10),
                test_output.cpu().numpy(), 'MSE')
            print('Test accuracy', test_acc)
            print('Test loss: {}'.format(test_loss))
            print('-----------------')
            conv_params = []
            for c in ensembles:
                t = time.time()
                ds = self._shape_parameter_to_conv_net(c)
                self.timings['shape_parameters'].append(time.time() - t)
                t = time.time()
                self.conv_net.set_parameter(ds)
                self.timings['set_parameters'].append(time.time() - t)
                params, _, _ = self.conv_net(inputs)
                conv_params.append(params.t())
            conv_params = torch.stack(conv_params)
            outs = {
                'conv_params': conv_params,
                'conv_loss': float(conv_loss),
                'input': self.inputs.squeeze(),
                'targets': self.labels
            }
        return outs

    def _shape_parameter_to_conv_net(self, params):
        param_dict = dict()
        start = 0
        for key in self.conv_net.state_dict().keys():
            shape = self.conv_net.state_dict()[key].shape
            length = self.conv_net.state_dict()[key].nelement()
            end = start + length
            param_dict[key] = params[start:end].reshape(shape)
            start = end
        return param_dict
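
Examples #3 to #5 unpack the forward pass as (outputs, act1, act2) and call conv_net.set_parameter(ds), while Example #1 expects a single output; the ConvNet itself is never shown. A minimal sketch of a network compatible with the three-output variant, assuming 1x28x28 MNIST inputs and the conv1/conv2/fc1 layer names referenced in the comments above:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
        self.fc1 = nn.Linear(32 * 4 * 4, 10)

    def forward(self, x):
        # return the logits together with both intermediate activations
        act1 = F.relu(F.max_pool2d(self.conv1(x), 2))     # (N, 16, 12, 12)
        act2 = F.relu(F.max_pool2d(self.conv2(act1), 2))  # (N, 32, 4, 4)
        out = self.fc1(act2.view(act2.size(0), -1))
        return out, act1, act2

    def set_parameter(self, param_dict):
        # load a {state_dict key: array/tensor} mapping produced by
        # _shape_parameter_to_conv_net
        state = {k: torch.as_tensor(v, dtype=torch.float32)
                 for k, v in param_dict.items()}
        self.load_state_dict(state)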
Example #5
class MnistOptimizee(torch.nn.Module):
    def __init__(self, seed, n_ensembles, batch_size, root):
        super(MnistOptimizee, self).__init__()

        self.random_state = np.random.RandomState(seed=seed)

        self.n_ensembles = n_ensembles
        self.batch_size = batch_size
        self.root = root

        self.conv_net = ConvNet().to(device)
        self.criterion = nn.CrossEntropyLoss()
        self.data_loader = DataLoader()
        self.data_loader.init_iterators(self.root, self.batch_size)
        self.dataiter_mnist = self.data_loader.dataiter_mnist
        self.testiter_mnist = self.data_loader.testiter_mnist

        self.generation = 0
        self.inputs, self.labels = self.dataiter_mnist()
        self.inputs = self.inputs.to(device)
        self.labels = self.labels.to(device)
        self.test_input, self.test_label = self.testiter_mnist()
        self.test_input = self.test_input.to(device)
        self.test_label = self.test_label.to(device)

        # For plotting
        self.train_pred = []
        self.test_pred = []
        self.train_acc = []
        self.test_acc = []
        self.train_cost = []
        self.test_cost = []
        self.train_loss = []
        self.test_loss = []
        self.targets = []
        self.output_activity_train = []
        self.output_activity_test = []
        self.act_func = {
            'act1': [],
            'act2': [],
            'act1_mean': [],
            'act2_mean': [],
            'act1_std': [],
            'act2_std': [],
            'act3': [],
            'act3_mean': [],
            'act3_std': []
        }
        self.test_act_func = {
            'act1': [],
            'act2': [],
            'act1_mean': [],
            'act2_mean': [],
            'act1_std': [],
            'act2_std': [],
            'act3': [],
            'act3_mean': [],
            'act3_std': []
        }

        self.targets.append(self.labels.cpu().numpy())

        # Covariance noise matrix
        self.cov = 0.0
        self.length = 0
        for key in self.conv_net.state_dict().keys():
            self.length += self.conv_net.state_dict()[key].nelement()

    def create_individual(self):
        # Flatten all network weights and biases into one vector per
        # ensemble member and initialize each from N(0, 0.1).
        conv_ensembles = []
        with torch.no_grad():
            for _ in range(self.n_ensembles):
                conv_ensembles.append(
                    np.random.normal(0, 0.1, size=self.length))
            return dict(conv_params=torch.as_tensor(np.array(conv_ensembles),
                                                    device=device),
                        targets=self.labels,
                        input=self.inputs.squeeze())

    def load_model(self, path='conv_params.npy'):
        print('Loading model from path: {}'.format(path))
        conv_params = np.load(path, allow_pickle=True).item()
        conv_ensembles = conv_params.get('ensemble')
        return dict(conv_params=torch.as_tensor(np.array(conv_ensembles),
                                                device=device),
                    targets=self.labels,
                    input=self.inputs.squeeze())

    @staticmethod
    def _he_init(weights, gain=0):
        """
        He- or Kaiming- initialization as in He et al., "Delving deep into
        rectifiers: Surpassing human-level performance on ImageNet
        classification". Values are sampled from
        :math:`\\mathcal{N}(0, \\text{std})` where

        .. math::
        \text{std} = \\sqrt{\\frac{2}{(1 + a^2) \\times \text{fan\\_in}}}

        Note: Only for the case that the non-linearity of the network
            activation is `relu`

        :param weights, tensor
        :param gain, additional scaling factor, Default is 0
        :return: numpy nd array, random array of size `weights`
        """
        fan_in = torch.nn.init._calculate_correct_fan(weights, 'fan_in')
        stddev = np.sqrt(2. / fan_in * (1 + gain**2))
        return np.random.normal(0, stddev, weights.numel())

    def set_parameters(self, ensembles):
        # set the new parameters for the network
        conv_params = ensembles.mean(0)
        ds = self._shape_parameter_to_conv_net(conv_params)
        self.conv_net.set_parameter(ds)
        print('---- Train -----')
        print('Generation ', self.generation)
        generation_change = 8
        with torch.no_grad():
            inputs = self.inputs.to(device)
            labels = self.labels.to(device)

            if self.generation % generation_change == 0:
                self.inputs, self.labels = self.dataiter_mnist()
                self.inputs = self.inputs.to(device)
                self.labels = self.labels.to(device)
                print('New MNIST set used at generation {}'.format(
                    self.generation))
                # append the outputs
                self.targets.append(self.labels.cpu().numpy())

            outputs, act1, act2 = self.conv_net(inputs)
            act3 = outputs
            self.act_func['act1'] = act1.cpu().numpy()
            self.act_func['act2'] = act2.cpu().numpy()
            self.act_func['act3'] = act3.cpu().numpy()
            self.act_func['act1_mean'].append(act1.mean().item())
            self.act_func['act2_mean'].append(act2.mean().item())
            self.act_func['act3_mean'].append(act3.mean().item())
            self.act_func['act1_std'].append(act1.std().item())
            self.act_func['act2_std'].append(act2.std().item())
            self.act_func['act3_std'].append(act3.std().item())
            self.output_activity_train.append(
                F.softmax(outputs, dim=1).cpu().numpy())
            conv_loss = self.criterion(outputs, labels).item()
            self.train_loss.append(conv_loss)
            train_cost = _calculate_cost(
                _encode_targets(labels, 10),
                F.softmax(outputs, dim=1).cpu().numpy(), 'MSE')
            train_acc = score(
                labels.cpu().numpy(),
                np.argmax(F.softmax(outputs, dim=1).cpu().numpy(), 1))
            print('Cost: ', train_cost)
            print('Accuracy: ', train_acc)
            print('Loss:', conv_loss)
            self.train_cost.append(train_cost)
            self.train_acc.append(train_acc)
            self.train_pred.append(
                np.argmax(F.softmax(outputs, dim=1).cpu().numpy(), 1))

            print('---- Test -----')
            test_output, act1, act2 = self.conv_net(self.test_input)
            test_loss = self.criterion(test_output, self.test_label).item()
            self.test_act_func['act1'] = act1.cpu().numpy()
            self.test_act_func['act2'] = act2.cpu().numpy()
            self.test_act_func['act1_mean'].append(act1.mean().item())
            self.test_act_func['act2_mean'].append(act2.mean().item())
            self.test_act_func['act3_mean'].append(test_output.mean().item())
            self.test_act_func['act1_std'].append(act1.std().item())
            self.test_act_func['act2_std'].append(act2.std().item())
            self.test_act_func['act3_std'].append(test_output.std().item())
            test_output = test_output.cpu().numpy()
            self.test_act_func['act3'] = test_output
            test_acc = score(self.test_label.cpu().numpy(),
                             np.argmax(test_output, 1))
            test_cost = _calculate_cost(
                _encode_targets(self.test_label.cpu().numpy(), 10),
                test_output, 'MSE')
            print('Test accuracy', test_acc)
            print('Test loss: {}'.format(test_loss))
            self.test_acc.append(test_acc)
            self.test_pred.append(np.argmax(test_output, 1))
            self.test_cost.append(test_cost)
            self.output_activity_test.append(test_output)
            self.test_loss.append(test_loss)
            print('-----------------')
            conv_params = []
            for c in ensembles:
                ds = self._shape_parameter_to_conv_net(c)
                self.conv_net.set_parameter(ds)
                params, _, _ = self.conv_net(inputs)
                conv_params.append(params.t().cpu().numpy())

            outs = {
                'conv_params': torch.as_tensor(np.array(conv_params),
                                               device=device),
                'conv_loss': float(conv_loss),
                'input': self.inputs.squeeze(),
                'targets': self.labels
            }
        return outs

    def _shape_parameter_to_conv_net(self, params):
        param_dict = dict()
        start = 0
        for key in self.conv_net.state_dict().keys():
            shape = self.conv_net.state_dict()[key].shape
            length = self.conv_net.state_dict()[key].nelement()
            end = start + length
            param_dict[key] = params[start:end].reshape(shape)
            start = end
        return param_dict
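
Every example also instantiates a custom DataLoader wrapper (not torch.utils.data.DataLoader) whose init_iterators sets up iterator factories such as dataiter_mnist. A possible sketch built on torchvision's MNIST, matching the (root, batch_size) signature used in Examples #3 to #5 (Example #1 additionally takes a path for notMNIST, which would follow the same pattern):

import torch
import torchvision
from torchvision import transforms


class DataLoader:
    def init_iterators(self, root, batch_size):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))])
        train = torchvision.datasets.MNIST(root, train=True, download=True,
                                           transform=transform)
        test = torchvision.datasets.MNIST(root, train=False, download=True,
                                          transform=transform)
        self._train_iter = iter(torch.utils.data.DataLoader(
            train, batch_size=batch_size, shuffle=True))
        self._test_iter = iter(torch.utils.data.DataLoader(
            test, batch_size=batch_size, shuffle=True))

    def dataiter_mnist(self):
        # each call yields the next (inputs, labels) training batch;
        # restarting an exhausted iterator is omitted for brevity
        return next(self._train_iter)

    def testiter_mnist(self):
        return next(self._test_iter)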