def processData(device_index, start_samples, samples, federated,
                full_data_size, number_of_batches, parameter_server,
                sample_distribution):
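    """Run the local training loop for device `device_index`.

    Depending on the flags, the device trains either stand-alone (CL, all
    flags off), with consensus-driven federated averaging (CFA, `federated`),
    or through a parameter server (FA, `parameter_server`). Model, optimizer
    state and training variables are periodically checkpointed under results/
    so that an interrupted run can resume from disk.
    """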
    pause(5)  # give the parameter server (if any) time to start first
    checkpointpath1 = 'results/model{}.h5'.format(device_index)
    outfile = 'results/dump_train_variables{}.npz'.format(device_index)
    outfile_models = 'results/dump_train_model{}.npy'.format(device_index)
    global_model = 'results/model_global.npy'
    global_epoch = 'results/epoch_global.npy'
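
    # The loop below checkpoints its state at several points as three
    # artifacts: the Keras model (.h5), the training variables (.npz) and the
    # raw weights (.npy). A minimal helper factoring that recurring pattern
    # could look like this sketch (illustrative only and deliberately unused;
    # the original inline save blocks below are kept unchanged):
    def save_checkpoint_sketch(ckpt_model, frame_count, epoch_loss_history,
                               training_end, epoch_count, running_loss):
        ckpt_model.save(checkpointpath1, include_optimizer=True,
                        save_format='h5')
        np.savez(outfile,
                 frame_count=frame_count,
                 epoch_loss_history=epoch_loss_history,
                 training_end=training_end,
                 epoch_count=epoch_count,
                 loss=running_loss)
        np.save(outfile_models, np.asarray(ckpt_model.get_weights()))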

    np.random.seed(1)
    tf.random.set_seed(1)  # common seeds: every device starts from identical weights

    learning_rate = args.mu
    learning_rate_local = learning_rate

    # uniform probability of picking any device other than this one as neighbor
    B = np.ones((devices, devices)) - tf.one_hot(np.arange(devices), devices)
    Probabilities = B[device_index, :] / (devices - 1)
    training_signal = False

    # check for backup variables on start
    if os.path.isfile(checkpointpath1):
        train_start = False

        # restore the model from the saved checkpoint
        model = models.load_model(checkpointpath1)
        data_history = []
        label_history = []
        local_model_parameters = np.load(outfile_models, allow_pickle=True)
        model.set_weights(local_model_parameters.tolist())

        dump_vars = np.load(outfile, allow_pickle=True)
        frame_count = dump_vars['frame_count']
        epoch_loss_history = dump_vars['epoch_loss_history'].tolist()
        running_loss = np.mean(epoch_loss_history[-5:])
        epoch_count = dump_vars['epoch_count']
    else:
        train_start = True
        model = create_q_model()
        data_history = []
        label_history = []
        frame_count = 0
        # fresh training state
        epoch_loss_history = []
        epoch_count = 0
        running_loss = math.inf

    if parameter_server:
        epoch_global = 0

    training_end = False

    a = model.get_weights()  # snapshot of the current weights
    # Adam optimizer with gradient clipping
    optimizer = keras.optimizers.Adam(learning_rate=args.mu, clipnorm=1.0)
    # create a data handler (here radar data)
    start = time.time()
    data_handle = RadarData(filepath, device_index, start_samples, samples,
                            full_data_size, args.random_data_distribution)
    end = time.time()
    time_count = (end - start)
    print('Data handler created in {:.2f} s'.format(time_count))
    # create a consensus object
    cfa_consensus = CFA_process(devices, device_index, args.N)
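    # CFA: consensus-driven federated averaging; instead of a central server,
    # each device mixes its weights with those of args.N neighbors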

    while True:  # Run until solved
        # collect 1 batch
        frame_count += 1
        obs, labels = data_handle.getTrainingData(batch_size)
        data_batch = preprocess_observation(obs, batch_size)

        # Save data and labels in the current learning session
        data_history.append(data_batch)
        label_history.append(labels)

        if frame_count % number_of_batches == 0:
            if not parameter_server:
                epoch_count += 1
            # check scheduling for federated
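            # scheduling_tx[k, e] == 1 means device k is active in epoch e;
            # inactive devices freeze computing and just keep their checkpoint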
            if federated:
                if epoch_count == 1 or scheduling_tx[device_index,
                                                     epoch_count] == 1:
                    training_signal = False
                else:
                    # stop all computing, just save the previous model
                    training_signal = True
                    model_weights = np.asarray(model.get_weights())
                    model.save(checkpointpath1,
                               include_optimizer=True,
                               save_format='h5')
                    np.savez(outfile,
                             frame_count=frame_count,
                             epoch_loss_history=epoch_loss_history,
                             training_end=training_end,
                             epoch_count=epoch_count,
                             loss=running_loss)
                    np.save(outfile_models, model_weights)
            # check scheduling for parameter server
            if parameter_server:
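                # synchronization with the PS happens through the shared files
                # model_global.npy / epoch_global.npy written by the server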
                while not os.path.isfile(global_epoch):
                    # wait for the PS to publish the global epoch counter
                    print("waiting")
                    pause(1)
                try:
                    epoch_global = np.load(global_epoch, allow_pickle=True)
                except Exception:
                    pause(5)
                    print("retrying opening global epoch counter")
                    try:
                        epoch_global = np.load(global_epoch, allow_pickle=True)
                    except Exception:
                        print("failed reading global epoch")

                if epoch_global == 0:
                    training_signal = False
                elif scheduling_tx[device_index, epoch_global] == 1:
                    # scheduled: train only if a new global epoch has started
                    if epoch_global > epoch_count:
                        epoch_count = epoch_global
                        training_signal = False
                    else:
                        training_signal = True
                else:
                    # not scheduled: stop computing, keep the saved model
                    training_signal = True

                # always refresh the local model with the PS copy
                stop_aggregation = False
                while not os.path.isfile(global_model):
                    # wait for the PS to publish the global model
                    print("waiting")
                    pause(1)
                try:
                    model_global = np.load(global_model, allow_pickle=True)
                except Exception:
                    pause(5)
                    print("retrying opening global model")
                    try:
                        model_global = np.load(global_model, allow_pickle=True)
                    except Exception:
                        print("halting aggregation")
                        stop_aggregation = True

                if not stop_aggregation:
                    model.set_weights(model_global.tolist())

                if training_signal:
                    model_weights = np.asarray(model.get_weights())
                    model.save(checkpointpath1,
                               include_optimizer=True,
                               save_format='h5')
                    np.savez(outfile,
                             frame_count=frame_count,
                             epoch_loss_history=epoch_loss_history,
                             training_end=training_end,
                             epoch_count=epoch_count,
                             loss=running_loss)
                    np.save(outfile_models, model_weights)

        # local learning update every number_of_batches batches
        time_count = 0
        if frame_count % number_of_batches == 0 and not training_signal:
            # run local batches
            for i in range(number_of_batches):
                start = time.time()
                data_sample = np.array(data_history[i])
                label_sample = np.array(label_history[i])

                # one-hot mask: the loss is computed only on the output
                # corresponding to each sample's true class
                masks = tf.one_hot(label_sample, n_outputs)

                with tf.GradientTape() as tape:
                    # Train the model on data samples
                    classes = model(data_sample)
                    # Apply the masks
                    class_v = tf.reduce_sum(tf.multiply(classes, masks),
                                            axis=1)
                    # Calculate loss
                    loss = loss_function(label_sample, class_v)

                # Backpropagation
                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
                end = time.time()
                time_count = time_count + (end - start) / number_of_batches
            if not parameter_server and not federated:
                print('Average batch training time {:.2f} s'.format(time_count))
            del data_history
            del label_history
            data_history = []
            label_history = []

            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     epoch_loss_history=epoch_loss_history,
                     training_end=training_end,
                     epoch_count=epoch_count,
                     loss=running_loss)
            np.save(outfile_models, model_weights)

            # consensus round: register the fresh local weights
            cfa_consensus.update_local_model(model_weights)
            # neighbor = cfa_consensus.get_connectivity(device_index, args.N, devices) # fixed neighbor

            if not train_start:
                if federated and not training_signal:
                    eps_c = 1 / (args.N + 1)
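                    # consensus mixing weight: eps_c = 1/(N+1) weighs this
                    # device and its N neighbors equally (the exact update is
                    # implemented in federated_weights_computing)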
                    # apply consensus for model parameter
                    # neighbor = np.random.choice(np.arange(devices), args.N, p=Probabilities, replace=False) # choose neighbor
                    neighbor = np.random.choice(
                        indexes_tx[:, epoch_count - 1], args.N,
                        replace=False)  # choose neighbor
                    while device_index in neighbor:
                        neighbor = np.random.choice(
                            indexes_tx[:, epoch_count - 1],
                            args.N,
                            replace=False)  # redraw if this device chose itself
                    print(
                        "Consensus from neighbor {} for device {}, local loss {:.2f}"
                        .format(neighbor, device_index, loss.numpy()))

                    model.set_weights(
                        cfa_consensus.federated_weights_computing(
                            neighbor, args.N, epoch_count, eps_c, max_lag))
                    if cfa_consensus.getTrainingStatusFromNeightbor():
                        # a neighbor finished training with loss below target:
                        # transfer learning applies (copy and reuse its model)
                        training_signal = True  # stop local learning, keep validating
            else:
                print("Consensus warm up")
                train_start = False

            del model_weights

        # periodic validation pass for device 'device_index'
        if epoch_count > validation_start and frame_count % number_of_batches == 0:
            avg_cost = 0.
            for i in range(number_of_batches_for_validation):
                obs_valid, labels_valid = data_handle.getTestData(
                    batch_size, i)
                # obs_valid, labels_valid = data_handle.getRandomTestData(batch_size)
                data_valid = preprocess_observation(np.squeeze(obs_valid),
                                                    batch_size)
                data_sample = np.array(data_valid)
                label_sample = np.array(labels_valid)
                # Create a mask to calculate loss
                masks = tf.one_hot(label_sample, n_outputs)
                classes = model(data_sample)
                # Apply the masks
                class_v = tf.reduce_sum(tf.multiply(classes, masks), axis=1)
                # Calculate loss
                loss = loss_function(label_sample, class_v)
                avg_cost += loss / number_of_batches_for_validation  # validation loss
            epoch_loss_history.append(avg_cost)
            print("Device {} epoch count {}, validation loss {:.2f}".format(
                device_index, epoch_count, avg_cost))
            # mean loss over the last 5 epochs
            running_loss = np.mean(epoch_loss_history[-5:])

        if running_loss < target_loss:  # Condition to consider the task solved
            print(
                "Solved for device {} at epoch {} with average loss {:.2f} !".
                format(device_index, epoch_count, running_loss))
            training_end = True
            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     epoch_loss_history=epoch_loss_history,
                     training_end=training_end,
                     epoch_count=epoch_count,
                     loss=running_loss)
            np.save(outfile_models, model_weights)

            if federated:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "neighbors": args.N,
                    "active_devices": args.Ka_consensus,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            elif parameter_server:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "active_devices": active_devices_per_round,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            else:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }

            if federated:
                sio.savemat(
                    "results/matlab/CFA_device_{}_samples_{}_devices_{}_active_{}_neighbors_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(device_index, samples, devices, args.Ka_consensus,
                            args.N, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
                sio.savemat(
                    "CFA_device_{}_samples_{}_devices_{}_neighbors_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices, args.N,
                            number_of_batches, batch_size), dict_1)
            elif parameter_server:
                sio.savemat(
                    "results/matlab/FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size, args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
                sio.savemat(
                    "FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size), dict_1)
            else:  # CL
                sio.savemat(
                    "results/matlab/CL_samples_{}_devices_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(samples, devices, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
            break

        if epoch_count > max_epochs:  # stop simulation
            print("Unsolved for device {} at epoch {}!".format(
                device_index, epoch_count))
            training_end = True
            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     epoch_loss_history=epoch_loss_history,
                     training_end=training_end,
                     epoch_count=epoch_count,
                     loss=running_loss)
            np.save(outfile_models, model_weights)

            if federated:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "neighbors": args.N,
                    "active_devices": args.Ka_consensus,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            elif parameter_server:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "active_devices": active_devices_per_round,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            else:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }

            if federated:
                sio.savemat(
                    "results/matlab/CFA_device_{}_samples_{}_devices_{}_active_{}_neighbors_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(device_index, samples, devices, args.Ka_consensus,
                            args.N, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
                sio.savemat(
                    "CFA_device_{}_samples_{}_devices_{}_neighbors_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices, args.N,
                            number_of_batches, batch_size), dict_1)
            elif parameter_server:
                sio.savemat(
                    "results/matlab/FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size, args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
                sio.savemat(
                    "FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size), dict_1)
            else:  # CL
                sio.savemat(
                    "results/matlab/CL_samples_{}_devices_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(samples, devices, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
            break
    # after the main loop: purely local training for local_rounds epochs;
    # reset the training state accordingly
    epoch_loss_history = []
    epoch_count = 0
    running_loss = math.inf

    training_end = False
    # re-initialize the Adam optimizer for the local training rounds
    optimizer = keras.optimizers.Adam(learning_rate=args.mu, clipnorm=1.0)
    # create a data handler (here radar data)
    # data_handle = MnistData(device_index, start_index, training_set_per_device, validation_train, 0)
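    # noniid_assignment == 1 gives each device a task-specific (non-IID)
    # partition via RadarData_tasks; otherwise samples are drawn IID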
    if args.noniid_assignment == 1:
        data_handle = RadarData_tasks(filepath, device_index, start_index,
                                      training_set_per_device,
                                      validation_train, 0)
    else:
        data_handle = RadarData(filepath, device_index, start_index,
                                training_set_per_device, validation_train, 0)
    # create a consensus object (only for consensus)
    # cfa_consensus = CFA_process(device_index, args.N)

    model_weights = np.asarray(model.get_weights())
    layers = model_weights.size  # number of weight tensors in the model

    del model_weights

    while epoch_count < local_rounds:
        frame_count += 1
        obs, labels = data_handle.getTrainingData(batch_size)
        data_batch = preprocess_observation(obs, batch_size)
        # Save data and labels in the current learning session
        data_history.append(data_batch)
        label_history.append(labels)