Example #1
    def run_epoch(self, data, batch_size):
        ravel = lambda x, y: np.ravel(x) - np.ravel(y)
        sumall = lambda x, y: np.sum(ravel(x, y)**2)

        for batched_x, batched_y in batch_data(data, batch_size):

            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)

            global_w = self.gl_ws
            local_w = self.get_params()
            # accumulate squared L2 differences between local and global
            # weights for every variable except the first
            tup_diff = []
            for local_v, gl_v in zip(local_w[1:], global_w[1:]):
                tup_diff.append(sumall(local_v, gl_v))

            with self.graph.as_default():
                self.sess.run(self.train_op,
                              feed_dict={
                                  self.prox_term: tup_diff,
                                  self.features: input_data,
                                  self.labels: target_data
                              })
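Every example on this page calls a batch_data helper that is not shown here. Below is a minimal sketch of the interface assumed by the call sites that iterate over it directly (data as a dict {'x': [...], 'y': [...]}, an optional seed controlling shuffling). It is an illustrative reconstruction, not the implementation from any of the quoted repositories.

import random

def batch_data(data, batch_size, seed=None):
    """Hypothetical sketch: yield (x_batch, y_batch) tuples of size batch_size."""
    xs, ys = list(data['x']), list(data['y'])
    if seed is not None:
        # shuffle x and y jointly so pairs stay aligned
        order = list(range(len(ys)))
        random.Random(seed).shuffle(order)
        xs = [xs[i] for i in order]
        ys = [ys[i] for i in order]
    for i in range(0, len(ys), batch_size):
        yield xs[i:i + batch_size], ys[i:i + batch_size]

Note that several examples below instead unpack two parallel lists of batches (batched_x, batched_y = batch_data(data, batch_size)), so the helper's return type differs between repositories; the sketch above only matches the generator-style call sites.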
Example #2
    def run_epoch(self, data, batch_size):

        for batched_x, batched_y in batch_data(data,
                                               batch_size,
                                               seed=self.seed):

            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)

            with self.graph.as_default():
                self.sess.run(self.train_op,
                              feed_dict={
                                  self.features: input_data,
                                  self.labels: target_data
                              })
            print("sssssssssss\n\n\n", self.features)
            print(
                "*********\n\n",
                self.sess.run(self.outputs,
                              feed_dict={
                                  self.features: input_data,
                                  self.labels: target_data
                              })[:, :, :].shape)
            predictions = self.sess.run(self.pred,
                                        feed_dict={
                                            self.features: input_data,
                                            self.labels: target_data
                                        })
            print("*****PREDICTIONS****\n\n", batched_x[0],
                  self.softmax(predictions[0]), target_data[1], batched_y)
            self.anirban = self.sess.run(
                tf.reduce_mean(
                    tf.nn.softmax_cross_entropy_with_logits_v2(
                        logits=predictions, labels=target_data)))
            print("*****LOSS****\n\n", predictions, self.anirban)
Example #3
    def train(self, data, num_epochs=1, batch_size=10):
        """
        Trains the client model.
        Args:
            data: Dict of the form {'x': [list], 'y': [list]}.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.
        Return:
            comp: Number of FLOPs computed while training given data
            update: List of np.ndarray weights, with each weight array
                corresponding to a variable in the resulting graph
        """
        with self.graph.as_default():
            init_values = [self.sess.run(v) for v in tf.trainable_variables()]

        batched_x, batched_y = batch_data(data, batch_size)
        for _ in range(num_epochs):
            for i, raw_x_batch in enumerate(batched_x):
                input_data = self.process_x(raw_x_batch)
                raw_y_batch = batched_y[i]
                target_data = self.process_y(raw_y_batch)
                with self.graph.as_default():
                    self.sess.run(self.train_op,
                                  feed_dict={
                                      self.features: input_data,
                                      self.labels: target_data
                                  })
        with self.graph.as_default():
            update = [self.sess.run(v) for v in tf.trainable_variables()]
            update = [
                np.subtract(update[i], init_values[i])
                for i in range(len(update))
            ]
        comp = num_epochs * len(batched_y) * batch_size * self.flops
        return comp, update
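Example #3 returns comp (a FLOP estimate of the work done) and update (per-variable weight deltas relative to the initial values). As a hedged illustration of how a federated-averaging server loop might consume such return values; the global_weights and client_results names below are assumptions made for this sketch, not taken from the quoted code.

import numpy as np

def apply_client_updates(global_weights, client_results):
    """Illustrative FedAvg-style aggregation.

    client_results: list of (num_samples, update) pairs, where each update is
    a list of np.ndarray deltas aligned with global_weights.
    """
    total = float(sum(n for n, _ in client_results))
    new_weights = []
    for idx, w in enumerate(global_weights):
        # sample-weighted average of the clients' deltas for this variable
        delta = sum((n / total) * upd[idx] for n, upd in client_results)
        new_weights.append(w + delta)
    return new_weights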
Example #4
    def test(self, eval_data, train_data=None):
        """
        Tests the current model on the given data.

        Args:
            eval_data: dict of the form {'x': [list], 'y': [list]}
            train_data: None or same format as eval_data. If None, do not measure statistics on train_data.
        Return:
            dict of metrics that will be recorded by the simulation.
        """
        data_lst = [eval_data] if train_data is None else [eval_data, train_data]
        output = {'eval': [-float('inf'), -float('inf')], 'train': [-float('inf'), -float('inf')]}
        for data, data_type in zip(data_lst, ['eval', 'train']):
            total_loss, total_correct, count = 0.0, 0, 0
            batched_x, batched_y = batch_data(data, self.max_batch_size, shuffle=False, eval_mode=True)
            for x, y in zip(batched_x, batched_y):
                x_vecs = self.process_x(x)
                labels = self.process_y(y)
                with self.graph.as_default():
                    loss, correct = self.sess.run(
                        [self.loss_op, self.eval_metric_ops],
                        feed_dict={self.features: x_vecs, self.labels: labels}
                    )
                total_loss += loss * len(y)  # loss returns average over batch
                total_correct += correct  # eval_op returns sum over batch
                count += len(y)
            loss = total_loss / count
            acc = total_correct / count
            output[data_type] = [loss, acc]

        return {OptimLoggingKeys.TRAIN_LOSS_KEY: output['train'][0],
                OptimLoggingKeys.TRAIN_ACCURACY_KEY: output['train'][1],
                OptimLoggingKeys.EVAL_LOSS_KEY: output['eval'][0],
                OptimLoggingKeys.EVAL_ACCURACY_KEY: output['eval'][1]
                }
Example #5
 def train_tau(self, data, num_tau=1, batch_size=10):
     count = 0
     accumulative_gradients = None
     for batched_x, batched_y in batch_data(data, batch_size):
         random_num = np.random.random()
         if count > num_tau:
             break
         if random_num <= 0.8:  # use this batch with probability 0.8
             input_data = self.process_x(batched_x)
             target_data = self.process_y(batched_y)
             with self.graph.as_default():
                 _, gradient = self.sess.run(
                     [self.train_op, self.gradient_op],
                     feed_dict={
                         self.features: input_data,
                         self.labels: target_data
                     })
                 # accumulate this batch's gradient across the selected batches
                 gradient = np.array(gradient)
                 accumulative_gradients = (
                     gradient if accumulative_gradients is None
                     else accumulative_gradients + gradient)
             count += 1
     update = self.get_params()
     comp = num_tau * batch_size * self.flops
     return comp, update, accumulative_gradients
Example #6
    def train(self, data, num_epochs=1, batch_size=10):
        """
        Trains the client model.

        Args:
            data: Dict of the form {'x': [list], 'y': [list]}.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.
        Return:
            comp: Number of FLOPs computed while training given data
            update: List of np.ndarray weights, with each weight array
                corresponding to a variable in the resulting graph
        """
        # initialize as the server model.
        with self.graph.as_default():
            all_vars = tf.trainable_variables()
            for v in all_vars:
                v.load(self.init_vals[v.name], self.sess)

        with self.graph.as_default():
            init_values = [self.sess.run(v) for v in tf.trainable_variables()]

        delta_values = [np.zeros(np.shape(init_values[i])) for i in range(len(init_values))]

        batched_x, batched_y = batch_data(data, batch_size)
        #run_metadata = tf.RunMetadata()
        cnt = 0
        for _ in range(num_epochs):
            for i, raw_x_batch in enumerate(batched_x):
                cnt += 1
                input_data = self.process_x(raw_x_batch)
                raw_y_batch = batched_y[i]
                target_data = self.process_y(raw_y_batch)
                with self.graph.as_default():

                    self.sess.run(
                        self.train_op,
                        feed_dict={self.features: input_data, self.labels: target_data, self.is_train: True},
                        #options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
                        #run_metadata=run_metadata
                        )
                    #self.profiler.add_step(i, run_metadata)
                    #self.profiler.profile_graph(options=opts)
                    #opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
                    #profiler.profile_operations(options=opts)

        #profile_code_opt_builder.select(['micros','occurrence'])
        #profile_code_opt_builder.order_by('micros')
        #profile_code_opt_builder.with_max_depth(10)
        #self.profiler.profile_operations(profile_code_opt_builder.build())
        #self.profiler.profile_python(profile_code_opt_builder.build())

                    
        with self.graph.as_default():
            update = [self.sess.run(v) for v in tf.trainable_variables()]
            update = [np.subtract(update[i], init_values[i]) for i in range(len(update))]

        comp = num_epochs * len(batched_y) * batch_size * self.flops

        return comp, update
Example #7
    def test(self, data, batch_size=10):
        """
        Tests the current model on the given data.

        Args:
            data: dict of the form {'x': [list], 'y': [list]}
        Return:
            dict of metrics that will be recorded by the simulation.
        """
        tot_wtd_acc = 0
        tot_wtd_loss = 0
        count = 0
        for batched_x, batched_y in batch_data(data,
                                               batch_size,
                                               seed=self.seed):
            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)
            with self.graph.as_default():
                batch_tot_acc, batch_loss = self.sess.run(
                    [self.eval_metric_ops, self.loss],
                    feed_dict={
                        self.features: input_data,
                        self.labels: target_data
                    })
            tot_wtd_acc += batch_tot_acc
            tot_wtd_loss += batch_loss * input_data.shape[0]
            count += input_data.shape[0]

        acc = tot_wtd_acc / count
        loss = tot_wtd_loss / count
        return {ACCURACY_KEY: acc, 'loss': loss}
Example #8
    def train(self, data, server_w, server_grad, num_epochs=1, batch_size=10):
        """
        Trains the client model.
        Args:
            data: Dict of the form {'x': [list], 'y': [list]}.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.
        Return:
            comp: Number of FLOPs computed while training given data
            update: List of np.ndarray weights, with each weight array
                corresponding to a variable in the resulting graph
        """
        with self.graph.as_default():
            init_values = [self.sess.run(v) for v in tf.trainable_variables()]

        batched_x, batched_y = batch_data(data, batch_size)
        # calculate term b gradient using (1) server weights and (2) client data
        X = self.process_x(batched_x[0])
        y = self.process_y(batched_y[0])
        term_b = self.fetch_grad(server_w, X, y)

        # compute term c gradient using (1) server weights and (2) server data
        term_c = server_grad
        client_w = server_w

        # for _ in range(num_epochs):
        #     for i, raw_x_batch in enumerate(batched_x):
        #         input_data = self.process_x(raw_x_batch)
        #         raw_y_batch = batched_y[i]
        #         target_data = self.process_y(raw_y_batch)
        #         with self.graph.as_default():
        #             self.sess.run(
        #                 self.train_op,
        #                 feed_dict={self.features: input_data, self.labels: target_data}
        #             )
        #     compute term a gradient using (1) client weights and (2) client data
        # with self.graph.as_default():
        #     update = [self.sess.run(v) for v in tf.trainable_variables()]
        #     update = [np.subtract(update[i], init_values[i]) for i in range(len(update))]
        # comp = num_epochs * len(batched_y) * batch_size * self.flops
        # return comp, update

        for _ in range(num_epochs):
            term_a = self.fetch_grad(client_w, X, y)

            # update client weights
            client_w = client_w - self.lr * (term_a - term_b + term_c)

        ##TODO send client_w to the server to get the new avg weight
        ## Get back w_t+1 from server (after average)
        ## Recompute new local gradient
        ## Send gradient to the sever to get average gradient

        comp = 0
        return comp, client_w
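The loop at the end of Example #8 applies a gradient-corrected step of the form w <- w - lr * (grad_i(w) - grad_i(w_s) + grad(w_s)), where w_s are the server weights, grad_i uses the client's data, and grad(w_s) is the server gradient. Below is a self-contained numpy sketch of that same update rule on a toy least-squares problem; all data and names here are made up for illustration and are not from the quoted code.

import numpy as np

def toy_grad(w, X, y):
    """Gradient of the toy least-squares objective 0.5 * ||Xw - y||^2 / n."""
    return X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X_client, y_client = rng.normal(size=(32, 5)), rng.normal(size=32)    # client data
X_server, y_server = rng.normal(size=(128, 5)), rng.normal(size=128)  # server/global data

server_w = np.zeros(5)
server_grad = toy_grad(server_w, X_server, y_server)   # term c: global grad at server weights
term_b = toy_grad(server_w, X_client, y_client)        # term b: client grad at server weights
client_w, lr = server_w.copy(), 0.1

for _ in range(5):
    term_a = toy_grad(client_w, X_client, y_client)    # term a: client grad at current weights
    client_w = client_w - lr * (term_a - term_b + server_grad)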
Example #9
 def run_epochs(self, data, batch_size):
     for batched_x, batched_y in batch_data(data,
                                            batch_size,
                                            seed=self.seed):
         input_data = self.preprocess_x(batched_x)
         target_data = self.preprocess_y(batched_y)
         self.trainer0.zero_grad()
         y_hats = self.net(input_data)
         lss = self.losses(y_hats, target_data.type(torch.LongTensor))
         lss.backward()
         self.trainer0.step()
Example #10
    def _run_epoch(self, data, batch_size):
        for batched_x, batched_y in batch_data(data, batch_size, self.seed):
            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)

            with self.graph.as_default():
                self.sess.run(self.train_op,
                              feed_dict={
                                  self.features: input_data,
                                  self.labels: target_data
                              })
Example #11
    def run_epochs(self, seed, data, batch_size):
        for batched_x, batched_y in batch_data(data, batch_size, seed):
            input_data = self.preprocess_x(batched_x)
            target_data = self.preprocess_y(batched_y)
            num_batch = len(batched_y)

            # Set MXNET_ENFORCE_DETERMINISM=1 to avoid difference in
            # calculation precision.
            with autograd.record():
                y_hats = self.net(input_data)
                ls = self.loss(y_hats, target_data)
                ls.backward()

            self.trainer.step(num_batch)
Example #12
 def solve_inner(self, data, num_epochs, batch_size):
     '''
     Perform the inner optimization routine
     '''
     self.reset_meta()
     for _ in range(num_epochs):
         for X, y in batch_data(data, batch_size):
             # Perform the gradient step
             self.zero_grad()
             loss = self.client_model.loss_fn(self.client_model(X), y) + self.prox_term()
             loss.backward()
             self.step()
     soln = [p.clone() for p in self.client_model.parameters()]
     comp = num_epochs * (len(data['y'])//batch_size) * batch_size * self.flops
     return soln, comp
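Example #12 adds a self.prox_term() penalty to the data loss, which in FedProx-style local training is typically (mu / 2) * ||w - w_global||^2. Below is a hedged PyTorch sketch of such a penalty; mu and the parameter lists are illustrative assumptions, and the quoted class may compute its term differently.

import torch

def prox_term(local_params, global_params, mu=0.01):
    """Illustrative FedProx proximal penalty: (mu / 2) * ||w - w_global||^2."""
    penalty = 0.0
    for w, w_global in zip(local_params, global_params):
        # keep the global weights out of the autograd graph
        penalty = penalty + (w - w_global.detach()).pow(2).sum()
    return 0.5 * mu * penalty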
Example #13
    def run_epoch(self, data, batch_size):

        for batched_x, batched_y in batch_data(data, batch_size, seed=self.seed):
            
            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)
            
            with self.graph.as_default():
                _, tot_acc, loss = self.sess.run(
                    [self.train_op, self.eval_metric_ops, self.loss],
                    feed_dict={
                        self.features: input_data,
                        self.labels: target_data
                    })
        acc = float(tot_acc) / input_data.shape[0]
        return {'acc': acc, 'loss': loss}
Example #14
    def run_epoch(self, data, batch_size):
        # def run_epoch(self, data, batch_size, batch_num):
        # batch_index = 0
        for batched_x, batched_y in batch_data(data, batch_size, seed=self.seed):
            # if batch_index == batch_num:
            #     break
            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)

            with self.graph.as_default():
                self.sess.run(self.train_op,
                    feed_dict={
                        self.features: input_data,
                        self.labels: target_data
                    })
Example #15
 def test(self, data):
     '''
     Test the model on given data
     '''
     num_correct = 0
     num_samples = 0
     losses = []
     with torch.no_grad():
         for X, y in batch_data(data, batch_size=32):
             out = self.forward(X)
             preds = torch.argmax(out, dim=1)
             num_correct += (preds == y).int().sum().item()
             num_samples += X.shape[0]
             losses.append(self.loss_fn(out, y).item())
     loss = np.mean(losses)
     return loss, num_correct, num_samples
Example #16
 def run_epoch(self, data, batch_size):
     running_loss = 0.0
     count = 0
     self.net.train()
     for batched_x, batched_y in batch_data(data, batch_size, seed=self.seed):
         self.net.to("cuda:0")
         input_data = self.process_x(batched_x).to("cuda:0")
         target_data = self.process_y(batched_y).to("cuda:0")
         self._optimizer.zero_grad()
         self.outputs = self.net(input_data)
         loss = self.criterion(self.outputs, target_data)
         loss.backward()
         self._optimizer.step()
         running_loss += loss.item()  # detach so the autograd graph is not kept alive
         count += 1
     self.net.to("cpu")  # just to save gpu memory
     return running_loss / count
Example #17
    def train(self, data, num_epochs=1, batch_size=10, lr=None):
        """
        Trains the client model.

        Args:
            data: Dict of the form {'x': [list], 'y': [list]}.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.
        Return:
            comp: Number of FLOPs computed while training given data
            update: List of np.ndarray weights, with each weight array
                corresponding to a variable in the resulting graph
            averaged_loss: average of stochastic loss in the final epoch
        """
        if lr is None:
            lr = self.lr
        averaged_loss = 0.0

        batched_x, batched_y = batch_data(data,
                                          batch_size,
                                          rng=self.rng,
                                          shuffle=True)
        if self.optimizer.w is None:
            self.optimizer.initialize_w()

        for epoch in range(num_epochs):
            total_loss = 0.0

            for i, raw_x_batch in enumerate(batched_x):
                input_data = self.process_x(raw_x_batch)
                raw_y_batch = batched_y[i]
                target_data = self.process_y(raw_y_batch)

                loss = self.optimizer.run_step(input_data, target_data)
                total_loss += loss
            averaged_loss = total_loss / len(batched_x)
        # print('inner opt:', epoch, averaged_loss)

        self.optimizer.end_local_updates()  # required for pytorch models
        update = np.copy(self.optimizer.w - self.optimizer.w_on_last_update)

        self.optimizer.update_w()

        comp = num_epochs * len(batched_y) * batch_size * self.flops
        return comp, update, averaged_loss
Example #18
 def solve_inner(self, data, num_epochs, batch_size):
     '''
     Perform the inner optimization routine
     '''
     self.reset_meta()
     for _ in range(num_epochs):
         for X, y in batch_data(data, batch_size):
             # Perform the gradient step
             self.zero_grad()
             loss = self.client_model.loss_fn(self.client_model(X), y)
             loss.backward()
             self.step()
             # Track the gradient for global step
             self.add_grad([p.grad for p in self.client_model.parameters()])
     soln = self.global_step()
     #comp = num_epochs * (len(data['y'])//batch_size) * batch_size * self.flops
     comp = 0  # TODO
     return soln, comp
Example #19
    def run_epoch(self, data, batch_size):

        for batched_x, batched_y in batch_data(data,
                                               batch_size,
                                               seed=self.seed):

            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)

            # print(f"np.round(input_data, 6) {np.round(input_data, 6)}")
            # print(target_data)

            with self.graph.as_default():
                self.sess.run(self.train_op,
                              feed_dict={
                                  self.features: np.round(input_data, 6),
                                  self.labels: target_data
                              })
    def prepare_test(self, data, batch_size, min_loss):
        # try to reach convergence before testing
        # return loss for diag
        loss_list = list()
        first = True
        for batched_x, batched_y in batch_data(data, batch_size):
            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)

            with self.graph.as_default():
                _, loss = self.sess.run(
                    [self.train_op, self.loss],
                    feed_dict={
                        self.features: input_data,
                        self.labels: target_data
                    })

            loss_list.append(loss)
            # TODO: stop once the new loss is almost stable
        return loss_list
Example #21
    def solve_inner(self, data, ref_params, ref_grad, num_epochs, batch_size):
        '''
        Perform the inner optimization routine
        '''
        self.reset_meta()

        cur_param = list(self.client_model.parameters())
        for _ in range(num_epochs):
            for X, y in batch_data(data, batch_size):
                # Update the parameter iterate
                cur_param = self.step(
                    cur_param, self.client_model.gradient(cur_param, X, y),
                    self.client_model.gradient(ref_params, X, y), ref_grad)
                # Update the gradient sum
                self.add_grad_at_ref(
                    self.client_model.gradient(self.orig, X, y))

        comp = 0  # TODO
        soln = [p.clone() for p in cur_param]
        return (soln, self.grad_mean()), comp
Example #22
    def run_epoch(self, data, batch_size):
        for batched_x, batched_y in batch_data(data, batch_size):

            input_data = self.process_x(batched_x)
            target_data = self.process_y(batched_y)

            with self.graph.as_default():
                grads = self.sess.run(self.grad_op,
                                      feed_dict={
                                          self.features: input_data,
                                          self.labels: target_data
                                      })
                # Each element of grads is a (gradient, variable) pair.
                grads = [g[0] for g in grads]
                # TODO:R Remove run(train_op)
                # self.sess.run(self.train_op,
                #     feed_dict={
                #         self.features: input_data,
                #         self.labels: target_data
                #     })
        return grads
Example #23
    def train(self, data, num_epochs=1, batch_size=10, lr=None):
        """
        Trains the client model.

        Args:
            data: Dict of the form {'x': [list], 'y': [list]}.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.
        Return:
            comp: Number of FLOPs computed while training given data
            update: List of np.ndarray weights, with each weight array
                corresponding to a variable in the resulting graph
            averaged_loss: average of stochastic loss in the final epoch
        """
        if lr is None:
            lr = self.lr
        averaged_loss = 0.0
        with self.graph.as_default():
            init_values = [self.sess.run(v) for v in tf.trainable_variables()]

        batched_x, batched_y = batch_data(data, batch_size, rng=self.rng, shuffle=True)
        for epoch in range(num_epochs):
            total_loss = 0.0
            for i, raw_x_batch in enumerate(batched_x):
                input_data = self.process_x(raw_x_batch)
                raw_y_batch = batched_y[i]
                target_data = self.process_y(raw_y_batch)
                with self.graph.as_default():
                    loss, _ = self.sess.run(
                        [self.loss_op, self.train_op],
                        feed_dict={self.features: input_data, self.labels: target_data, self.learning_rate_tensor: lr}
                    )
                total_loss += loss
            averaged_loss = total_loss / len(batched_x)
        with self.graph.as_default():
            update = [self.sess.run(v) for v in tf.trainable_variables()]
            update = [np.subtract(update[i], init_values[i]) for i in range(len(update))]
        comp = num_epochs * len(batched_y) * batch_size * self.flops
        return comp, update, averaged_loss
Example #24
    def train(self, data, num_epochs=1, batch_size=10):
        """
        Trains the client model.

        Args:
            data: Dict of the form {'x': [list], 'y': [list]}.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.
        Return:
            comp: Number of FLOPs computed while training given data
            update: List of np.ndarray weights, with each weight array
                corresponding to a variable in the resulting graph
        """
        # initialize as the server model.
        with self.graph.as_default():
            all_vars = tf.trainable_variables()
            for v in all_vars:
                v.load(self.init_vals[v.name], self.sess)

        with self.graph.as_default():
            init_values = [self.sess.run(v) for v in tf.trainable_variables()]
            delta_values = [
                np.zeros(np.shape(init_values[i]))
                for i in range(len(init_values))
            ]

        batched_x, batched_y = batch_data(data, batch_size)
        cnt = 0
        sum_samples = 0
        for _ in range(num_epochs):
            for i, raw_x_batch in enumerate(batched_x):
                input_data = self.process_x(raw_x_batch)
                raw_y_batch = batched_y[i]
                target_data = self.process_y(raw_y_batch)
                with self.graph.as_default():
                    _, var_grad = self.sess.run(
                        [self.train_op, self.var_grad],
                        feed_dict={
                            self.features: input_data,
                            self.labels: target_data,
                            self.is_train: True
                        })
                    #self.sess.run(
                    #    self.train_op,
                    #feed_dict={self.features: input_data, self.labels: target_data, self.is_train:True}
                    #)
                    delta_values = [
                        np.add(delta_values[i],
                               pow(self.gamma, cnt) * var_grad[i])
                        for i in range(len(var_grad))
                    ]
                sum_samples += pow(self.gamma, cnt)
                cnt += 1

        with self.graph.as_default():
            update = [self.sess.run(v) for v in tf.trainable_variables()]
            update = [
                np.subtract(update[i], init_values[i])
                for i in range(len(update))
            ]
            delta_values = [
                -delta_values[i] / sum_samples * cnt
                for i in range(len(delta_values))
            ]

        #comp = num_epochs * len(batched_y) * batch_size * self.flops
        comp = num_epochs * len(batched_y) * batch_size
        #return comp, update
        return comp, delta_values
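Example #24 returns a geometrically discounted gradient average rather than a weight delta: each batch gradient is weighted by gamma**cnt, the sum is normalized by the total weight (sum_samples), then negated and scaled by the step count. Below is a tiny numpy sketch of just that bookkeeping in isolation, with made-up values for illustration.

import numpy as np

gamma = 0.9
batch_grads = [np.array([1.0, 2.0]), np.array([0.5, -1.0]), np.array([2.0, 0.0])]

delta = np.zeros(2)
sum_weights = 0.0
for cnt, g in enumerate(batch_grads):
    delta += (gamma ** cnt) * g      # discounted accumulation, as in the training loop
    sum_weights += gamma ** cnt

# normalize by the total weight, negate, and scale by the number of steps taken
delta = -delta / sum_weights * len(batch_grads)
print(delta)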
Example #25
    def test(self,
             eval_data,
             train_data=None,
             split_by_user=True,
             train_users=True):
        """
        Tests the current model on the given data.
        Args:
            eval_data: dict of the form {'x': [list], 'y': [list]}
            train_data: None or same format as eval_data. If None, do not measure statistics on train_data.
        Return:
            dict of metrics that will be recorded by the simulation.
        """
        if split_by_user:
            output = {
                'eval': [-float('inf'), -float('inf')],
                'train': [-float('inf'), -float('inf')]
            }

            if self.optimizer.w is None:
                self.optimizer.initialize_w()

            total_loss, total_correct, count = 0.0, 0, 0
            batched_x, batched_y = batch_data(eval_data,
                                              self.max_batch_size,
                                              shuffle=False,
                                              eval_mode=True)
            for x, y in zip(batched_x, batched_y):
                x_vecs = self.process_x(x)
                labels = self.process_y(y)

                loss = self.optimizer.loss(x_vecs, labels)
                correct = self.optimizer.correct(x_vecs, labels)

                total_loss += loss * len(y)  # loss returns average over batch
                total_correct += correct  # eval_op returns sum over batch
                count += len(y)
                # counter_1 += 1
            loss = total_loss / count
            acc = total_correct / count
            if train_users:
                output['train'] = [loss, acc]
            else:
                output['eval'] = [loss, acc]

            return {
                ACCURACY_KEY: output['eval'][1],
                OptimLoggingKeys.TRAIN_LOSS_KEY: output['train'][0],
                OptimLoggingKeys.TRAIN_ACCURACY_KEY: output['train'][1],
                OptimLoggingKeys.EVAL_LOSS_KEY: output['eval'][0],
                OptimLoggingKeys.EVAL_ACCURACY_KEY: output['eval'][1]
            }
        else:
            data_lst = [eval_data] if train_data is None else [eval_data, train_data]
            output = {
                'eval': [-float('inf'), -float('inf')],
                'train': [-float('inf'), -float('inf')]
            }

            if self.optimizer.w is None:
                self.optimizer.initialize_w()
            # counter_0 = 0
            for data, data_type in zip(data_lst, ['eval', 'train']):
                # counter_1 = 0
                total_loss, total_correct, count = 0.0, 0, 0
                batched_x, batched_y = batch_data(data,
                                                  self.max_batch_size,
                                                  shuffle=False,
                                                  eval_mode=True)
                for x, y in zip(batched_x, batched_y):
                    x_vecs = self.process_x(x)
                    labels = self.process_y(y)

                    loss = self.optimizer.loss(x_vecs, labels)
                    correct = self.optimizer.correct(x_vecs, labels)

                    total_loss += loss * len(y)  # loss returns average over batch
                    total_correct += correct  # eval_op returns sum over batch
                    count += len(y)
                    # counter_1 += 1
                loss = total_loss / count
                acc = total_correct / count
                output[data_type] = [loss, acc]
                # counter_1 += 1

            return {
                ACCURACY_KEY: output['eval'][1],
                OptimLoggingKeys.TRAIN_LOSS_KEY: output['train'][0],
                OptimLoggingKeys.TRAIN_ACCURACY_KEY: output['train'][1],
                OptimLoggingKeys.EVAL_LOSS_KEY: output['eval'][0],
                OptimLoggingKeys.EVAL_ACCURACY_KEY: output['eval'][1]
            }