def _pack_layer_weights(weights):
    """Pack a list of per-layer weight arrays into a 1-D object array.

    Calling ``np.asarray`` directly on a list of differently shaped arrays is
    rejected by recent NumPy releases.  Filling a pre-allocated object array
    element by element always yields one entry per layer, which is the layout
    the rest of the training loop indexes into (``weights[k]``) and converts
    back with ``.tolist()``.
    """
    packed = np.empty(len(weights), dtype=object)
    for layer, layer_weights in enumerate(weights):
        packed[layer] = layer_weights
    return packed


def _save_training_checkpoint(model, checkpoint_path, variables_path,
                              weights_path, frame_count, epoch_loss_history,
                              training_end, epoch_count, running_loss):
    """Persist the model, the loop counters and the packed layer weights.

    Returns the packed weights (see ``_pack_layer_weights``) so the caller can
    reuse them for the consensus / parameter-server steps.
    """
    model_weights = _pack_layer_weights(model.get_weights())
    model.save(checkpoint_path, include_optimizer=True, save_format='h5')
    np.savez(variables_path,
             frame_count=frame_count,
             epoch_loss_history=epoch_loss_history,
             training_end=training_end,
             epoch_count=epoch_count,
             loss=running_loss)
    np.save(weights_path, model_weights)
    return model_weights


def _load_global_model_with_retry(path):
    """Load the global model written by the parameter server.

    One retry is attempted after a 5 s pause.  Returns ``None`` when both
    attempts fail, in which case the caller skips the aggregation step.
    """
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; this is
    # only acceptable if the file is produced by a trusted process -- confirm.
    try:
        return np.load(path, allow_pickle=True)
    except Exception:  # do not swallow KeyboardInterrupt / SystemExit
        pause(5)
        print("retrying opening global model")
        try:
            return np.load(path, allow_pickle=True)
        except Exception:
            print("halting aggregation")
            return None


def _export_training_results(device_index, samples, federated,
                             parameter_server, number_of_batches,
                             epoch_loss_history):
    """Write the run summary (loss history and settings) to .mat files.

    The file-name prefix and the extra keys depend on the learning mode:
    ``CFA_`` when ``federated`` is set, ``FA2_`` when only
    ``parameter_server`` is set, ``CL_`` otherwise.  The first two modes write
    one file under ``results/matlab/`` and a second one in the current working
    directory; the ``CL_`` mode writes only the ``results/matlab/`` file.
    """
    dict_1 = {
        "epoch_loss_history": epoch_loss_history,
        "federated": federated,
        "parameter_server": parameter_server,
        "devices": devices,
    }
    if federated:
        dict_1["neighbors"] = args.N
        dict_1["active_devices"] = args.Ka_consensus
    elif parameter_server:
        dict_1["active_devices"] = active_devices_per_round
    dict_1.update({
        "batches": number_of_batches,
        "batch_size": batch_size,
        "samples": samples,
        "noniid": args.noniid_assignment,
        "data_distribution": args.random_data_distribution,
    })

    if federated:
        sio.savemat(
            "results/matlab/CFA_device_{}_samples_{}_devices_{}_active_{}_neighbors_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
            .format(device_index, samples, devices, args.Ka_consensus,
                    args.N, number_of_batches, batch_size,
                    args.noniid_assignment, args.run,
                    args.random_data_distribution), dict_1)
        sio.savemat(
            "CFA_device_{}_samples_{}_devices_{}_neighbors_{}_batches_{}_size{}.mat"
            .format(device_index, samples, devices, args.N,
                    number_of_batches, batch_size), dict_1)
    elif parameter_server:
        sio.savemat(
            "results/matlab/FA2_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
            .format(device_index, samples, devices,
                    active_devices_per_round, number_of_batches,
                    batch_size, args.noniid_assignment, args.run,
                    args.random_data_distribution), dict_1)
        sio.savemat(
            "FA2_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}.mat"
            .format(device_index, samples, devices,
                    active_devices_per_round, number_of_batches,
                    batch_size), dict_1)
    else:  # CL
        sio.savemat(
            "results/matlab/CL_samples_{}_devices_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
            .format(samples, devices, number_of_batches, batch_size,
                    args.noniid_assignment, args.run,
                    args.random_data_distribution), dict_1)


def processData(device_index, start_samples, samples, federated,
                full_data_size, number_of_batches, parameter_server,
                sample_distribution):
    """Run the local training loop of one device until it stops.

    Each iteration fetches one training batch.  Every ``number_of_batches``
    batches the device performs one local epoch (one gradient step per stored
    batch), checkpoints its state under ``results/``, optionally mixes its
    weights with those of randomly chosen neighbours (``federated``) and/or
    overwrites them with the global model written by a parameter server
    (``parameter_server``), and -- once ``epoch_count`` exceeds
    ``validation_start`` -- evaluates the validation loss.

    The loop ends when the mean of the last five validation losses drops
    below ``target_loss``, when a neighbour reports that it finished training,
    or when ``epoch_count`` exceeds ``max_epochs``.  On exit a final
    checkpoint and the .mat result files are written.

    Args:
        device_index: index of this device; used in file names and to exclude
            the device from its own neighbour set.
        start_samples: forwarded to the data handler.
        samples: forwarded to the data handler and stored in the results.
        federated: enable consensus with neighbours.
        full_data_size: forwarded to the data handler.
        number_of_batches: number of batches per local epoch.
        parameter_server: enable synchronisation with the global model file.
        sample_distribution: accepted for interface compatibility; not used
            by this function.

    Relies on module-level configuration (``args``, ``devices``,
    ``batch_size``, ``optimizer``, ``loss_function``, ...).
    """
    pause(5)  # PS server (if any) starts first
    checkpointpath1 = 'results/model{}.h5'.format(device_index)
    outfile = 'results/dump_train_variables{}.npz'.format(device_index)
    outfile_models = 'results/dump_train_model{}.npy'.format(device_index)
    global_model = 'results/model_global.npy'

    np.random.seed(1)
    tf.random.set_seed(1)  # common initialization

    training_signal = False
    # batches collected during the current learning session
    data_history = []
    label_history = []

    if os.path.isfile(checkpointpath1):
        # resume from the backup written by a previous run
        train_start = False
        model = models.load_model(checkpointpath1)
        # NOTE(review): allow_pickle=True is required for the object array of
        # layer weights; only load files written by this simulation.
        local_model_parameters = np.load(outfile_models, allow_pickle=True)
        model.set_weights(local_model_parameters.tolist())

        # close the archive before it is overwritten by later checkpoints
        with np.load(outfile, allow_pickle=True) as dump_vars:
            frame_count = dump_vars['frame_count']
            epoch_loss_history = dump_vars['epoch_loss_history'].tolist()
            epoch_count = dump_vars['epoch_count']
        # no validation recorded yet -> same "unsolved" value as a fresh start
        running_loss = (np.mean(epoch_loss_history[-5:])
                        if epoch_loss_history else math.inf)
    else:
        train_start = True
        model = create_q_model()
        frame_count = 0
        epoch_loss_history = []
        epoch_count = 0
        running_loss = math.inf

    # create the data object (IID or task-based, non-IID assignment)
    if args.noniid_assignment == 0:
        data_handle = CIFARData(device_index, start_samples, samples,
                                full_data_size, args.random_data_distribution)
    else:
        data_handle = CIFARData_task(device_index, start_samples, samples,
                                     full_data_size,
                                     args.random_data_distribution)
    # create a consensus object
    cfa_consensus = CFA_process(devices, device_index, args.N)

    while True:  # Run until solved
        # collect 1 batch
        frame_count += 1
        obs, labels = data_handle.getTrainingData(batch_size)
        data_history.append(preprocess_observation(obs, batch_size))
        label_history.append(labels)

        epoch_boundary = frame_count % number_of_batches == 0

        # Local learning update every "number of batches" batches
        if epoch_boundary and not training_signal:
            epoch_count += 1
            for stored_data, stored_labels in zip(
                    data_history[:number_of_batches],
                    label_history[:number_of_batches]):
                data_sample = np.array(stored_data)
                label_sample = np.array(stored_labels)

                # one-hot mask selecting the output of the labelled class
                masks = tf.one_hot(label_sample, n_outputs)

                with tf.GradientTape() as tape:
                    classes = model(data_sample)
                    class_v = tf.reduce_sum(tf.multiply(classes, masks),
                                            axis=1)
                    loss = loss_function(label_sample, class_v)

                # Backpropagation
                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))

            data_history = []
            label_history = []

            model_weights = _save_training_checkpoint(
                model, checkpointpath1, outfile, outfile_models, frame_count,
                epoch_loss_history, False, epoch_count, running_loss)

            # Consensus round: publish the local model, then pick neighbours
            cfa_consensus.update_local_model(model_weights)
            # Redraw until this device is not among the chosen neighbours.
            # np.any is needed because the draw is an array of args.N entries.
            # NOTE(review): never terminates if args.N >= devices, since then
            # every draw without replacement contains device_index.
            neighbor = np.random.choice(np.arange(devices),
                                        args.N,
                                        replace=False)
            while np.any(neighbor == device_index):
                neighbor = np.random.choice(np.arange(devices),
                                            args.N,
                                            replace=False)

            if train_start:
                # first epoch of a fresh run: skip the consensus step
                print("Warm up")
                train_start = False
            elif federated:
                eps_c = 1 / (args.N + 1)
                # apply consensus for model parameter
                print(
                    "Consensus from neighbor {} for device {}, local loss {:.2f}"
                    .format(neighbor, device_index, loss.numpy()))
                model.set_weights(
                    cfa_consensus.federated_weights_computing(
                        neighbor, args.N, frame_count, eps_c, max_lag))
                if cfa_consensus.getTrainingStatusFromNeightbor():
                    # a neighbor completed the training with loss < target:
                    # stop local learning and reuse the received model
                    training_signal = True

            # check if parameter server is enabled
            if parameter_server:
                pause(refresh_server)
                while not os.path.isfile(global_model):
                    print("waiting")
                    pause(1)
                model_global = _load_global_model_with_retry(global_model)
                if model_global is not None:
                    # replace every local layer with the global one
                    for k in range(cfa_consensus.layers):
                        model_weights[k] = model_global[k]
                    model.set_weights(model_weights.tolist())

        # validation tool for device 'device_index'
        if epoch_count > validation_start and epoch_boundary:
            avg_cost = 0.
            for i in range(number_of_batches_for_validation):
                obs_valid, labels_valid = data_handle.getTestData(
                    batch_size, i)
                data_valid = preprocess_observation(np.squeeze(obs_valid),
                                                    batch_size)
                data_sample = np.array(data_valid)
                label_sample = np.array(labels_valid)
                masks = tf.one_hot(label_sample, n_outputs)
                classes = model(data_sample)
                class_v = tf.reduce_sum(tf.multiply(classes, masks), axis=1)
                loss = loss_function(label_sample, class_v)
                avg_cost += loss / number_of_batches_for_validation
            epoch_loss_history.append(avg_cost)
            print("Device {} epoch count {}, validation loss {:.2f}".format(
                device_index, epoch_count, avg_cost))
            # mean loss for last 5 epochs
            running_loss = np.mean(epoch_loss_history[-5:])

        # Condition to consider the task solved
        solved = running_loss < target_loss or training_signal
        if not (solved or epoch_count > max_epochs):
            continue

        if solved:
            print(
                "Solved for device {} at epoch {} with average loss {:.2f} !".
                format(device_index, epoch_count, running_loss))
        else:  # stop simulation
            print("Unsolved for device {} at epoch {}!".format(
                device_index, epoch_count))

        # final checkpoint (training_end=True) and result export
        _save_training_checkpoint(
            model, checkpointpath1, outfile, outfile_models, frame_count,
            epoch_loss_history, True, epoch_count, running_loss)
        _export_training_results(device_index, samples, federated,
                                 parameter_server, number_of_batches,
                                 epoch_loss_history)
        return
def processData(device_index, number_positions_devices, federated,
                target_server, initialization, update_after_actions):
    pause(5)  # server start first
    checkpointpath1 = 'results/model{}.h5'.format(device_index)
    checkpointpath2 = 'results/model_target{}.h5'.format(device_index)
    outfile = 'results/dump_train_variables{}.npz'.format(device_index)
    outfile_models = 'results/dump_train_model{}.npy'.format(device_index)
    global_target_model = 'results/model_target_global.npy'

    if args.centralized == 0:
        max_steps_per_episode = number_positions_devices  # other option 1000
    else:
        max_steps_per_episode = number_positions_devices * devices  # other option 1000

    n_file_cfa = "models_saved/CFA_robot_{}_number_{}_neighbors_{}_explored_pos_{}_update_{}_run_{}.mat".format(
        device_index, devices, args.N, number_positions_devices,
        update_consensus, args.run)
    n_file_cfa_h5 = "models_saved/CFA_robot_{}_number_{}_neighbors_{}_explored_pos_{}_update_{}_run_{}.h5".format(
        device_index, devices, args.N, number_positions_devices,
        update_consensus, args.run)
    n_file_fa = "models_saved/FA_robot_{}_number_{}_explored_pos_{}_update_{}_run_{}.mat".format(
        device_index, devices, number_positions_devices, update_consensus,
        args.run)
    n_file_fa_h5 = "models_saved/FA_robot_{}_number_{}_explored_pos_{}_update_{}_run_{}.h5".format(
        device_index, devices, number_positions_devices, update_consensus,
        args.run)
    n_file_cl = "models_saved/CL_datacenter_{}_number_{}_explored_pos_{}_update_{}_run{}.mat".format(
        device_index, devices, number_positions_devices, update_after_actions,
        args.run)
    n_file_cl_h5 = "models_saved/CL_datacenter_{}_number_{}_explored_pos_{}_update_{}_run{}.h5".format(
        device_index, devices, number_positions_devices, update_after_actions,
        args.run)
    n_file_isolated = "models_saved/Isolated_robot_{}_number_{}_explored_pos_{}_update_{}_run_{}.mat".format(
        device_index, devices, number_positions_devices, update_after_actions,
        args.run)

    n_file_isolated_h5 = "models_saved/Isolated_robot_{}_number_{}_explored_pos_{}_update_{}_run_{}.h5".format(
        device_index, devices, number_positions_devices, update_after_actions,
        args.run)

    #np.random.seed(1)
    #tf.random.set_seed(1)  # common initialization

    B = np.ones((devices, devices)) - tf.one_hot(np.arange(devices), devices)
    Probabilities = B[device_index, :] / (devices - 1)
    training_signal = False

    # check for backup variables
    if os.path.isfile(checkpointpath1):
        train_start = False

        # backup the model and the model target
        model = models.load_model(checkpointpath1)
        model_target = models.load_model(checkpointpath2)
        local_model_parameters = np.load(outfile_models, allow_pickle=True)
        model.set_weights(local_model_parameters.tolist())

        dump_vars = np.load(outfile, allow_pickle=True)
        frame_count = dump_vars['frame_count']
        action_history = []
        state_history = []
        state_next_history = []
        rewards_history = []
        done_history = []
        episode_reward_history = dump_vars['episode_reward_history'].tolist()
        running_reward = dump_vars['running_reward']
        episode_count = dump_vars['episode_count']
        epsilon = dump_vars['epsilon']
    else:
        train_start = True
        model = create_q_model()
        model_target = create_q_model()
        frame_count = 0
        # Experience replay buffers
        action_history = []
        state_history = []
        state_next_history = []
        rewards_history = []
        done_history = []
        episode_reward_history = []
        running_reward = 0
        episode_count = 0
        epsilon = 1.0  # Epsilon greedy parameter

    training_end = False
    epsilon_min = 0.1  # Minimum epsilon greedy parameter
    epsilon_min_validation = 0.001  # for validation only
    epsilon_max = 1.0  # Maximum epsilon greedy parameter
    epsilon_interval = (
        epsilon_max - epsilon_min
    )  # Rate at which to reduce chance of random action being taken

    print("End loading file")
    print("Set epsilon to {} for device {}".format(epsilon, device_index))

    optimizer = keras.optimizers.Adam(learning_rate=args.mu, clipnorm=1.0)
    file_size = number_positions
    robot_trajectory = RobotTrajectory(filepath, lookuptab, filerewards,
                                       file_size, number_positions_devices)
    robot_trajectory_validation = RobotTrajectory(filepath, lookuptab,
                                                  filerewards, file_size,
                                                  number_positions)
    cfa_consensus = CFA_process(devices, device_index, args.N)
    inizialization_index = 0

    while True:  # Run until solved
        # state = np.array(env.reset())

        if args.centralized == 0:
            [obs, reward, done
             ] = robot_trajectory.initialize(position_initial=initialization)
        else:  # CL learning
            [obs, reward, done] = robot_trajectory.initialize(
                position_initial=initialization[inizialization_index])
            #inizialization_index += 1
            #if inizialization_index % devices == 0:
            #    inizialization_index = 0

        state = preprocess_observation(np.squeeze(obs))
        # episode_reward = reward
        # neighbor = cfa_consensus.get_connectivity(device_index, args.N, devices)
        for timestep in range(1, max_steps_per_episode):
            # env.render(); Adding this line would show the attempts
            # of the agent in a pop up window.
            frame_count += 1
            if args.centralized == 1:  # check data obtained from multiple devices
                if timestep % number_positions_devices == 0:  # reinitialize
                    inizialization_index += 1
                    if inizialization_index % devices == 0:
                        inizialization_index = 0
                    [obs, reward, done] = robot_trajectory.initialize(
                        position_initial=initialization[inizialization_index])
                    state = preprocess_observation(np.squeeze(obs))

            # Use epsilon-greedy for exploration
            if frame_count < epsilon_random_frames or epsilon > np.random.rand(
                    1)[0]:
                # Take random action
                action = np.random.choice(n_outputs)
            else:
                # Predict action Q-values
                # From environment state
                state_tensor = tf.convert_to_tensor(state)
                state_tensor = tf.expand_dims(state_tensor, 0)
                action_probs = model(state_tensor, training=False)
                # Take best action
                action = tf.argmax(action_probs[0]).numpy()

            # Decay probability of taking random action
            epsilon -= epsilon_interval / epsilon_greedy_frames
            epsilon = max(epsilon, epsilon_min)

            # Apply the sampled action in our environment
            dev = 0
            [obs, reward, done] = robot_trajectory.implement(action, dev)
            state_next = preprocess_observation(np.squeeze(obs))
            state_next = np.array(state_next)

            # episode_reward += reward

            # Save actions and states in replay buffer
            action_history.append(action)
            state_history.append(state)
            state_next_history.append(state_next)
            done_history.append(done)
            rewards_history.append(reward)

            # Let's memorize what happened
            # replay_memory.append((state, action, reward, state_next, 1.0 - done))

            state = state_next

            # Update every fourth frame and once batch size is over 32
            if frame_count % update_after_actions == 0 and len(
                    done_history) > batch_size and not training_signal:
                # start = time.time()
                # Get indices of samples for replay buffers
                indices = np.random.choice(range(len(done_history)),
                                           size=batch_size)

                # Using list comprehension to sample from replay buffer
                state_sample = np.array([state_history[i] for i in indices])
                state_next_sample = np.array(
                    [state_next_history[i] for i in indices])
                rewards_sample = [rewards_history[i] for i in indices]
                action_sample = [action_history[i] for i in indices]
                done_sample = tf.convert_to_tensor(
                    [float(done_history[i]) for i in indices])

                # Build the updated Q-values for the sampled future states
                # Use the target model for stability
                future_rewards = model_target.predict(state_next_sample)

                # Bellman equation
                # Q value = reward + discount factor * expected future reward
                updated_q_values = rewards_sample + gamma * tf.reduce_max(
                    future_rewards, axis=1)

                # If final frame set the last value to -1 or max_reward
                # updated_q_values = updated_q_values * (1 - done_sample) - done_sample
                updated_q_values = updated_q_values * (
                    1 - done_sample) + done_sample * max_reward  # test
                # Create a mask so we only calculate loss on the updated Q-values
                masks = tf.one_hot(action_sample, num_actions)

                with tf.GradientTape() as tape:
                    # Train the model on the states and updated Q-values
                    q_values = model(state_sample)

                    # Apply the masks to the Q-values to get the Q-value for action taken
                    q_action = tf.reduce_sum(tf.multiply(q_values, masks),
                                             axis=1)
                    # Calculate loss between new Q-value and old Q-value
                    loss = loss_function(updated_q_values, q_action)

                # Backpropagation
                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
                #end = time.time()
                #print("Time for 1 minibatch (32 observations): {}".format(end-start))

            if frame_count % update_consensus == 0 and len(
                    done_history) > batch_size:
                model_weights = np.asarray(model.get_weights())
                # update local model
                cfa_consensus.update_local_model(model_weights)
                # neighbor = cfa_consensus.get_connectivity(device_index, args.N, devices) # fixed neighbor
                # neighbor = np.random.choice(np.arange(devices), args.N, p=Probabilities) # choose neighbor
                neighbor = np.random.choice(np.arange(devices), args.N)
                while device_index == neighbor:
                    neighbor = np.random.choice(np.arange(devices), args.N)

                if not train_start:
                    print(
                        "Episode {}, frame count {}, running reward: {:.2f}, loss {:.2f}"
                        .format(episode_count, frame_count, running_reward,
                                loss.numpy()))
                    if federated and not training_signal:
                        print(
                            "Neighbor {} for device {} at episode {} and frame_count {}"
                            .format(neighbor, device_index, episode_count,
                                    frame_count))

                        eps_c = 1 / (args.N + 1)
                        # apply consensus for model parameter
                        model.set_weights(
                            cfa_consensus.federated_weights_computing(
                                neighbor, args.N, frame_count, eps_c,
                                update_consensus))
                        if cfa_consensus.getTrainingStatusFromNeightbor():
                            training_signal = True
                    elif target_server:
                        stop_aggregation = False
                        while not os.path.isfile(global_target_model):
                            # implementing consensus
                            print("waiting")
                            pause(1)
                        try:
                            model_target_global = np.load(global_target_model,
                                                          allow_pickle=True)
                        except:
                            pause(5)
                            print("retrying opening target model")
                            try:
                                model_target_global = np.load(
                                    global_target_model, allow_pickle=True)
                            except:
                                print("halting aggregation")
                                stop_aggregation = True

                        if not stop_aggregation:
                            print("Device {} at episode {}".format(
                                device_index, episode_count))
                            model.set_weights(model_target_global.tolist())
                else:
                    print("Warm up")
                    train_start = False

                # model.save(checkpointpath1, include_optimizer=True, save_format='h5')
                np.savez(outfile,
                         frame_count=frame_count,
                         episode_reward_history=episode_reward_history,
                         running_reward=running_reward,
                         episode_count=episode_count,
                         epsilon=epsilon,
                         training_end=training_end)
                np.save(outfile_models, model_weights)
                del model_weights

            if frame_count % update_target_network == 0:
                # update the the target network with new weights
                model.save(checkpointpath1,
                           include_optimizer=True,
                           save_format='h5')
                model_target.save(checkpointpath2,
                                  include_optimizer=True,
                                  save_format='h5')
                # np.savez_compressed(outfile, frame_count=frame_count, episode_reward_history=episode_reward_history,
                # running_reward=running_reward, episode_count=episode_count, epsilon=epsilon, training_end=training_end)
                stop_aggregation = False

                model_target.set_weights(model.get_weights())

                # model_loaded = models.load_model('model.h5')
                # model_loaded._make_predict_function() # unclear
                # Log details

            # Limit the state and reward history
            if len(rewards_history) > max_memory_length:
                del rewards_history[:1]
                del state_history[:1]
                del state_next_history[:1]
                del action_history[:1]
                del done_history[:1]

            if done:
                # print("initializing")
                break

        # validation tool for device 'device_index'
        trajectory = np.zeros(number_positions, dtype=int)
        [obs, reward, done] = robot_trajectory_validation.initialize(
        )  # initialize in position 0 (entrance)
        state = preprocess_observation(np.squeeze(obs))
        episode_reward = reward
        tr_count = 0
        # print(training_signal)
        for timestep_v in range(
                1, number_positions
        ):  # validate overall the full position set (number of positions)
            trajectory[tr_count] = robot_trajectory_validation.getPosition()
            tr_count += 1
            # with probability epsilon_min_validation take a random action,
            # otherwise follow the model's best predicted action
            if epsilon_min_validation > np.random.rand(1)[0]:
                # Take random action
                action = np.random.choice(n_outputs)
            else:
                # Predict action Q-values
                # From environment state
                state_tensor = tf.convert_to_tensor(state)
                state_tensor = tf.expand_dims(state_tensor, 0)
                action_probs = model(state_tensor, training=False)
                # Take best action
                action = tf.argmax(action_probs[0]).numpy()
            dev = 0
            [obs, reward,
             done] = robot_trajectory_validation.implement(action, dev)
            state_next = preprocess_observation(np.squeeze(obs))
            state_next = np.array(state_next)
            episode_reward += reward
            state = state_next
            if done:
                print("found an exit with reward {}".format(reward))
                break
        trajectory[tr_count] = robot_trajectory_validation.getPosition()

        # Update running reward to check condition for solving
        episode_reward_history.append(episode_reward)
        if len(episode_reward_history) > max_episodes:  # check memory
            del episode_reward_history[:1]
        # mean reward for last 10 episodes (change depending on application)
        running_reward = np.mean(episode_reward_history[-10:])

        episode_count += 1

        if running_reward > target_reward:  # Condition to consider the task solved
            print(
                "Solved for device {} at episode {} with running reward {:.2f} !"
                .format(device_index, episode_count, running_reward))
            print(trajectory)
            print("Reward {:.2f} for trajectory".format(episode_reward))

            training_end = True
            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            # model_target.save(checkpointpath2, include_optimizer=True, save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     episode_reward_history=episode_reward_history,
                     running_reward=running_reward,
                     episode_count=episode_count,
                     epsilon=epsilon,
                     training_end=training_end)
            np.save(outfile_models, model_weights)
            dict_1 = {"episode_reward_history": episode_reward_history}
            if federated:
                sio.savemat(n_file_cfa, dict_1)
                model.save(n_file_cfa_h5,
                           include_optimizer=True,
                           save_format='h5')
            elif target_server:
                sio.savemat(n_file_fa, dict_1)
                sio.savemat("FA_device_{}.mat".format(device_index), dict_1)
                model.save(n_file_fa_h5,
                           include_optimizer=True,
                           save_format='h5')
            elif args.centralized == 1:
                sio.savemat(n_file_cl, dict_1)
                model.save(n_file_cl_h5,
                           include_optimizer=True,
                           save_format='h5')
            else:
                sio.savemat(n_file_isolated, dict_1)
                model.save(n_file_isolated_h5,
                           include_optimizer=True,
                           save_format='h5')
            break

        if episode_count > max_episodes:  # stop simulation
            print("Unsolved for device {} at episode {}!".format(
                device_index, episode_count))
            training_end = True
            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            # model_target.save(checkpointpath2, include_optimizer=True, save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     episode_reward_history=episode_reward_history,
                     running_reward=running_reward,
                     episode_count=episode_count,
                     epsilon=epsilon,
                     training_end=training_end)
            np.save(outfile_models, model_weights)
            dict_1 = {"episode_reward_history": episode_reward_history}
            if federated:
                sio.savemat(n_file_cfa, dict_1)
                model.save(n_file_cfa_h5,
                           include_optimizer=True,
                           save_format='h5')
            elif target_server:
                sio.savemat(n_file_fa, dict_1)
                sio.savemat("FA_device_{}.mat".format(device_index), dict_1)
                model.save(n_file_fa_h5,
                           include_optimizer=True,
                           save_format='h5')
            elif args.centralized == 1:
                sio.savemat(n_file_cl, dict_1)
                model.save(n_file_cl_h5,
                           include_optimizer=True,
                           save_format='h5')
            else:
                sio.savemat(n_file_isolated, dict_1)
                model.save(n_file_isolated_h5,
                           include_optimizer=True,
                           save_format='h5')
            break