예제 #1
0
def _weights_as_object_array(weights):
    """Pack per-layer weight arrays into a 1-D ``object`` ndarray.

    ``np.asarray(model.get_weights())`` relies on NumPy's old ragged-array
    fallback: the layer arrays have different shapes, and NumPy >= 1.24
    raises ``ValueError`` for such input instead of building an object array.
    Filling a pre-allocated object array element by element always yields the
    1-D array of layers that ``np.save`` / ``.tolist()`` round-trip.

    Args:
        weights: sequence of per-layer arrays (e.g. ``model.get_weights()``).

    Returns:
        ``np.ndarray`` of dtype ``object`` whose i-th element is ``weights[i]``.
    """
    packed = np.empty(len(weights), dtype=object)
    for index, layer in enumerate(weights):
        # element-wise assignment keeps NumPy from trying to broadcast layers
        packed[index] = layer
    return packed


def _save_checkpoint(model, checkpoint_path, variables_path, weights_path,
                     frame_count, epoch_loss_history, training_end,
                     epoch_count, running_loss):
    """Persist the model, the training counters and the raw weights.

    Writes the Keras model (HDF5, optimizer included) to ``checkpoint_path``,
    the counters/loss history to the ``.npz`` ``variables_path`` and the layer
    weights to the ``.npy`` ``weights_path``.

    Returns:
        The weights that were saved, as a 1-D object ndarray.
    """
    model_weights = _weights_as_object_array(model.get_weights())
    model.save(checkpoint_path, include_optimizer=True, save_format='h5')
    np.savez(variables_path, frame_count=frame_count,
             epoch_loss_history=epoch_loss_history,
             training_end=training_end, epoch_count=epoch_count,
             loss=running_loss)
    np.save(weights_path, model_weights)
    return model_weights


def _load_shared_npy(path, retry_message, failure_message):
    """Wait for a shared ``.npy`` file to appear, then load it.

    Polls every second (printing "waiting") until ``path`` exists. If the
    first load fails, waits 5 s and retries once.
    NOTE(review): presumably the failure comes from another process (e.g. the
    parameter server) rewriting the file concurrently — confirm.

    Returns:
        ``(value, True)`` on success, ``(None, False)`` if both attempts fail.
    """
    while not os.path.isfile(path):
        print("waiting")
        pause(1)
    try:
        return np.load(path, allow_pickle=True), True
    except Exception:
        pause(5)
        print(retry_message)
        try:
            return np.load(path, allow_pickle=True), True
        except Exception:
            print(failure_message)
            return None, False


def _save_final_results(device_index, samples, federated, parameter_server,
                        number_of_batches, epoch_loss_history):
    """Write the validation-loss history and run configuration as ``.mat`` files.

    CFA (federated) and FA (parameter server) runs write one file under
    ``results/matlab/`` plus a short-named copy in the working directory;
    centralized (CL) runs write only the ``results/matlab/`` file.
    """
    summary = {"epoch_loss_history": epoch_loss_history, "federated": federated,
               "parameter_server": parameter_server, "devices": devices}
    if federated:
        summary["neighbors"] = args.N
        summary["active_devices"] = args.Ka_consensus
    elif parameter_server:
        summary["active_devices"] = active_devices_per_round
    summary.update({"batches": number_of_batches, "batch_size": batch_size,
                    "samples": samples, "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution})

    if federated:
        sio.savemat(
            "results/matlab/CFA_device_{}_samples_{}_devices_{}_active_{}_neighbors_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat".format(
                device_index, samples, devices, args.Ka_consensus, args.N, number_of_batches, batch_size,
                args.noniid_assignment, args.run, args.random_data_distribution), summary)
        sio.savemat(
            "CFA_device_{}_samples_{}_devices_{}_neighbors_{}_batches_{}_size{}.mat".format(
                device_index, samples, devices, args.N, number_of_batches, batch_size), summary)
    elif parameter_server:
        sio.savemat(
            "results/matlab/FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat".format(
                device_index, samples, devices, active_devices_per_round, number_of_batches, batch_size,
                args.noniid_assignment, args.run, args.random_data_distribution), summary)
        sio.savemat(
            "FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}.mat".format(
                device_index, samples, devices, active_devices_per_round, number_of_batches, batch_size),
            summary)
    else:  # CL
        sio.savemat(
            "results/matlab/CL_samples_{}_devices_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat".format(
                samples, devices, number_of_batches, batch_size,
                args.noniid_assignment, args.run, args.random_data_distribution), summary)


def processData(device_index, start_samples, samples, federated, full_data_size, number_of_batches, parameter_server, sample_distribution):
    """Run the training loop of one device until it is solved or runs out of epochs.

    One training batch is drawn per iteration; every ``number_of_batches``
    iterations is an epoch boundary, at which the device
      * checks its schedule (``scheduling_tx``; in parameter-server mode the
        epoch counter and the model are read from ``results/epoch_global.npy``
        and ``results/model_global.npy``),
      * trains locally on the collected batches if scheduled,
      * in federated (CFA) mode, averages its model with ``args.N`` neighbors
        drawn from ``indexes_tx``,
      * validates once ``epoch_count > validation_start``.
    State is checkpointed under ``results/`` so a restarted process resumes.
    The loop stops when ``running_loss < target_loss`` or
    ``epoch_count > max_epochs``, after writing the ``.mat`` result files.

    Args:
        device_index: index of this device (used in file names and schedules).
        start_samples, samples, full_data_size: forwarded to ``MnistData``.
        federated: enable consensus-based federated averaging (CFA).
        number_of_batches: batches per local epoch.
        parameter_server: follow a parameter server's global model/epoch.
        sample_distribution: unused by this function.

    Relies on module-level globals (``args``, ``devices``, ``batch_size``,
    ``n_outputs``, ``loss_function``, ``scheduling_tx``, ``indexes_tx``,
    ``max_lag``, ``validation_start``, ``number_of_batches_for_validation``,
    ``target_loss``, ``max_epochs``, ``active_devices_per_round``).
    """
    pause(5)  # PS server (if any) starts first
    checkpointpath1 = 'results/model{}.h5'.format(device_index)
    outfile = 'results/dump_train_variables{}.npz'.format(device_index)
    outfile_models = 'results/dump_train_model{}.npy'.format(device_index)
    global_model = 'results/model_global.npy'
    global_epoch = 'results/epoch_global.npy'

    np.random.seed(1)
    tf.random.set_seed(1)  # common initialization

    training_signal = False  # True => skip local training this epoch
    data_history = []
    label_history = []

    # check for backup variables on start
    if os.path.isfile(checkpointpath1):
        train_start = False
        model = models.load_model(checkpointpath1)
        local_model_parameters = np.load(outfile_models, allow_pickle=True)
        model.set_weights(local_model_parameters.tolist())
        # the NpzFile keeps its file open until closed
        with np.load(outfile, allow_pickle=True) as dump_vars:
            frame_count = dump_vars['frame_count']
            epoch_loss_history = dump_vars['epoch_loss_history'].tolist()
            epoch_count = dump_vars['epoch_count']
        running_loss = np.mean(epoch_loss_history[-5:])
    else:
        train_start = True
        model = create_q_model()
        frame_count = 0
        epoch_loss_history = []
        epoch_count = 0
        running_loss = math.inf

    if parameter_server:
        epoch_global = 0

    training_end = False

    # set an arbitrary optimizer, here Adam is used
    optimizer = keras.optimizers.Adam(learning_rate=args.mu, clipnorm=1.0)
    data_handle = MnistData(device_index, start_samples, samples, full_data_size, args.random_data_distribution)
    cfa_consensus = CFA_process(devices, device_index, args.N)

    def _save_state():
        # reads the loop variables at call time, so it always saves current values
        return _save_checkpoint(model, checkpointpath1, outfile, outfile_models,
                                frame_count, epoch_loss_history, training_end,
                                epoch_count, running_loss)

    while True:  # Run until solved
        # collect 1 batch
        frame_count += 1
        obs, labels = data_handle.getTrainingData(batch_size)
        data_history.append(preprocess_observation(obs, batch_size))
        label_history.append(labels)

        epoch_boundary = frame_count % number_of_batches == 0

        if epoch_boundary:
            if not parameter_server:
                epoch_count += 1
            # check scheduling for federated
            if federated:
                if epoch_count == 1 or scheduling_tx[device_index, epoch_count] == 1:
                    training_signal = False
                else:
                    # stop all computing, just save the previous model
                    training_signal = True
                    _save_state()
            # check scheduling for parameter server
            if parameter_server:
                loaded_epoch, ok = _load_shared_npy(
                    global_epoch, "retrying opening global epoch counter",
                    "failed reading global epoch")
                if ok:
                    epoch_global = loaded_epoch  # otherwise keep the last value

                if epoch_global == 0:
                    training_signal = False
                elif scheduling_tx[device_index, epoch_global] == 1:
                    if epoch_global > epoch_count:
                        epoch_count = epoch_global
                        training_signal = False
                    else:
                        training_signal = True
                else:
                    # stop all computing, just save the previous model
                    training_signal = True

                # always refresh the local model using the PS one
                model_global, ok = _load_shared_npy(
                    global_model, "retrying opening global model",
                    "halting aggregation")
                if ok:
                    model.set_weights(model_global.tolist())

                if training_signal:
                    _save_state()

        # Local learning update every "number of batches" batches
        if epoch_boundary and not training_signal:
            time_count = 0
            for data_batch, batch_labels in zip(data_history, label_history):
                start = time.time()
                data_sample = np.array(data_batch)
                label_sample = np.array(batch_labels)

                # Create a mask to calculate loss
                masks = tf.one_hot(label_sample, n_outputs)

                with tf.GradientTape() as tape:
                    classes = model(data_sample)
                    # keep only the output of the true class
                    class_v = tf.reduce_sum(tf.multiply(classes, masks), axis=1)
                    loss = loss_function(label_sample, class_v)

                # Backpropagation
                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads, model.trainable_variables))
                time_count += (time.time() - start) / number_of_batches
            if not parameter_server and not federated:
                print('Average batch training time {:.2f}'.format(time_count))

            model_weights = _save_state()

            # Consensus round: publish the freshly trained local model
            cfa_consensus.update_local_model(model_weights)

            if not train_start:
                if federated and not training_signal:
                    eps_c = 1 / (args.N + 1)
                    # apply consensus for model parameter: draw args.N neighbors
                    # scheduled this epoch, redrawing while this device is among them
                    candidates = indexes_tx[:, epoch_count - 1]
                    neighbor = np.random.choice(candidates, args.N, replace=False)
                    # `neighbor` is an array; `neighbor == device_index` is
                    # ambiguous as a condition when args.N > 1
                    while device_index in neighbor:
                        neighbor = np.random.choice(candidates, args.N, replace=False)
                    print("Consensus from neighbor {} for device {}, local loss {:.2f}".format(neighbor, device_index,
                                                                                               loss.numpy()))

                    model.set_weights(cfa_consensus.federated_weights_computing(neighbor, args.N, epoch_count, eps_c, max_lag))
                    if cfa_consensus.getTrainingStatusFromNeightbor():
                        # a neighbor completed the training with loss < target: reuse its model
                        training_signal = True  # stop local learning, just do validation
            else:
                print("Warm up")
                train_start = False

        if epoch_boundary:
            # Drop this epoch's batches whether or not they were trained on;
            # keeping them while the device is unscheduled grew the buffers
            # without bound and later trained on stale batches.
            data_history = []
            label_history = []

        # validation tool for device 'device_index'
        if epoch_count > validation_start and epoch_boundary:
            avg_cost = 0.
            for i in range(number_of_batches_for_validation):
                obs_valid, labels_valid = data_handle.getTestData(batch_size, i)
                data_sample = np.array(preprocess_observation(np.squeeze(obs_valid), batch_size))
                label_sample = np.array(labels_valid)
                masks = tf.one_hot(label_sample, n_outputs)
                classes = model(data_sample)
                class_v = tf.reduce_sum(tf.multiply(classes, masks), axis=1)
                loss = loss_function(label_sample, class_v)
                avg_cost += loss / number_of_batches_for_validation
            epoch_loss_history.append(avg_cost)
            print("Device {} epoch count {}, validation loss {:.2f}".format(device_index, epoch_count,
                                                                                 avg_cost))
            # only the latest validation loss is used here (the resume path
            # above averages the last 5)
            running_loss = np.mean(epoch_loss_history[-1:])

        if running_loss < target_loss:  # Condition to consider the task solved
            print("Solved for device {} at epoch {} with average loss {:.2f} !".format(device_index, epoch_count, running_loss))
            training_end = True
            _save_state()
            _save_final_results(device_index, samples, federated, parameter_server,
                                number_of_batches, epoch_loss_history)
            break

        if epoch_count > max_epochs:  # stop simulation
            print("Unsolved for device {} at epoch {}!".format(device_index, epoch_count))
            training_end = True
            _save_state()
            _save_final_results(device_index, samples, federated, parameter_server,
                                number_of_batches, epoch_loss_history)
            break
예제 #2
0
def processData(device_index, start_samples, samples, federated,
                full_data_size, number_of_batches, parameter_server,
                sample_distribution):
    pause(5)  # PS server (if any) starts first
    checkpointpath1 = 'results/model{}.h5'.format(device_index)
    outfile = 'results/dump_train_variables{}.npz'.format(device_index)
    outfile_models = 'results/dump_train_model{}.npy'.format(device_index)
    outfile_models_grad = 'results/dump_train_grad{}.npy'.format(device_index)
    global_model = 'results/model_global.npy'
    global_epoch = 'results/epoch_global.npy'

    #np.random.seed(1)
    #tf.random.set_seed(1)  # common initialization

    learning_rate = args.mu
    learning_rate_local = learning_rate

    B = np.ones((devices, devices)) - tf.one_hot(np.arange(devices), devices)
    Probabilities = B[device_index, :] / (devices - 1)
    training_signal = False

    # check for backup variables on start
    if os.path.isfile(checkpointpath1):
        train_start = False

        # backup the model and the model target
        model = models.load_model(checkpointpath1)
        model_transmitted = create_q_model()
        data_history = []
        label_history = []
        local_model_parameters = np.load(outfile_models, allow_pickle=True)
        model.set_weights(local_model_parameters.tolist())

        dump_vars = np.load(outfile, allow_pickle=True)
        frame_count = dump_vars['frame_count']
        epoch_loss_history = dump_vars['epoch_loss_history'].tolist()
        running_loss = np.mean(epoch_loss_history[-5:])
        epoch_count = dump_vars['epoch_count']
    else:
        train_start = True
        model = create_q_model()
        model_transmitted = create_q_model()
        data_history = []
        label_history = []
        frame_count = 0
        # Experience replay buffers
        epoch_loss_history = []
        epoch_count = 0
        running_loss = math.inf

    if parameter_server:
        epoch_global = 0

    training_end = False

    #a = model.get_weights()
    # set an arbitrary optimizer, here Adam is used
    optimizer = keras.optimizers.Adam(learning_rate=args.mu, clipnorm=1.0)
    #optimizer2 = keras.optimizers.SGD(learning_rate=args.mu2)
    optimizer2 = keras.optimizers.Adam(learning_rate=args.mu2, clipnorm=1.0)
    # create a data object (here radar data)
    # start = time.time()
    if args.noniid_assignment == 1:
        data_handle = MnistData_task(device_index, start_samples, samples,
                                     full_data_size,
                                     args.random_data_distribution)
    else:
        data_handle = MnistData(device_index, start_samples, samples,
                                full_data_size, args.random_data_distribution)

    # end = time.time()
    # time_count = (end - start)
    # print(Training time"time_count)
    # create a consensus object
    cfa_consensus = CFA_process(devices, device_index, args.N)

    while True:  # Run until solved
        # collect 1 batch
        frame_count += 1
        obs, labels = data_handle.getTrainingData(batch_size)
        data_batch = preprocess_observation(obs, batch_size)

        # Save data and labels in the current learning session
        data_history.append(data_batch)
        label_history.append(labels)

        if frame_count % number_of_batches == 0:
            if not parameter_server:
                epoch_count += 1
            # check scheduling for federated
            if federated:
                if epoch_count == 1 or scheduling_tx[device_index,
                                                     epoch_count] == 1:
                    training_signal = False
                else:
                    # stop all computing, just save the previous model
                    training_signal = True
                    model_weights = np.asarray(model.get_weights())
                    model.save(checkpointpath1,
                               include_optimizer=True,
                               save_format='h5')
                    np.savez(outfile,
                             frame_count=frame_count,
                             epoch_loss_history=epoch_loss_history,
                             training_end=training_end,
                             epoch_count=epoch_count,
                             loss=running_loss)
                    np.save(outfile_models, model_weights)

        # Local learning update every "number of batches" batches
        # time_count = 0
        if frame_count % number_of_batches == 0 and not training_signal:
            # run local batches
            for i in range(number_of_batches):
                start = time.time()
                data_sample = np.array(data_history[i])
                label_sample = np.array(label_history[i])

                # Create a mask to calculate loss
                masks = tf.one_hot(label_sample, n_outputs)

                with tf.GradientTape() as tape:
                    # Train the model on data samples
                    classes = model(data_sample)
                    # Apply the masks
                    class_v = tf.reduce_sum(tf.multiply(classes, masks),
                                            axis=1)
                    # Calculate loss
                    loss = loss_function(label_sample, class_v)

                # Backpropagation
                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(grads,
                                              model.trainable_variables))
                #end = time.time()
                #time_count = time_count + (end-start)/number_of_batches

            del data_history
            del label_history
            data_history = []
            label_history = []

            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     epoch_loss_history=epoch_loss_history,
                     training_end=training_end,
                     epoch_count=epoch_count,
                     loss=running_loss)
            np.save(outfile_models, model_weights)
            cfa_consensus.update_local_model(model_weights)
            grads_v = []
            for d in range(len(grads)):
                grads_v.append(grads[d].numpy())
            grads_v = np.asarray(grads_v)
            cfa_consensus.update_local_gradient(grads_v)

            # compute gradients for selected neighbors in get_tx_connectvity, obtain a new test observation from local database
            obs_t, labels_t = data_handle.getTrainingData(batch_size)
            data_batch_t = preprocess_observation(obs_t, batch_size)
            masks_t = tf.one_hot(labels_t, n_outputs)
            gradient_neighbor = cfa_consensus.get_tx_connectivity(
                device_index, args.N, devices)
            outfile_n = 'results/dump_train_variables{}.npz'.format(
                gradient_neighbor)
            outfile_models_n = 'results/dump_train_model{}.npy'.format(
                gradient_neighbor)
            neighbor_model_for_gradient, success = cfa_consensus.get_neighbor_weights(
                epoch_count, outfile_n, outfile_models_n, epoch=0, max_lag=1)
            if success:
                model_transmitted.set_weights(
                    neighbor_model_for_gradient.tolist())
            else:
                print("failed retrieving the model for gradient computation")

            with tf.GradientTape() as tape2:
                # Train the model on data samples
                classes = model_transmitted(data_batch_t)
                # Apply the masks
                class_v = tf.reduce_sum(tf.multiply(classes, masks_t), axis=1)
                # Calculate loss
                loss = loss_function(labels_t, class_v)

            # getting and save neighbor gradients
            grads_t = tape2.gradient(loss,
                                     model_transmitted.trainable_variables)
            grads_v = []
            for d in range(len(grads_t)):
                grads_v.append(grads_t[d].numpy())
            grads_v = np.asarray(grads_v)
            np.save(outfile_models_grad, grads_v)

            np.random.seed(1)
            tf.random.set_seed(1)  # common initialization
            if not train_start:
                if federated and not training_signal:
                    eps_c = args.eps
                    # apply consensus for model parameter
                    neighbor = cfa_consensus.get_connectivity(
                        device_index, args.N, devices)  # fixed neighbor
                    #if args.gradients == 0 or running_loss < 0.5:
                    if args.gradients == 0:
                        # random selection of neighor
                        # neighbor = np.random.choice(indexes_tx[:, epoch_count - 1], args.N, replace=False) # choose neighbor
                        # while neighbor == device_index:
                        #     neighbor = np.random.choice(indexes_tx[:, epoch_count - 1], args.N,
                        #                             replace=False)  # choose neighbor
                        print(
                            "Consensus from neighbor {} for device {}, local loss {:.2f}"
                            .format(neighbor, device_index, loss.numpy()))
                        model.set_weights(
                            cfa_consensus.federated_weights_computing(
                                neighbor, args.N, epoch_count, eps_c, max_lag))
                        if cfa_consensus.getTrainingStatusFromNeightbor():
                            training_signal = True  # stop local learning, just do validation
                    else:
                        # compute gradients as usual

                        print(
                            "Consensus from neighbor {} for device {}, local loss {:.2f}"
                            .format(neighbor, device_index, loss.numpy()))
                        print("Applying gradient updates...")
                        # model.set_weights(cfa_consensus.federated_weights_computing(neighbor, args.N, epoch_count, eps_c, max_lag))
                        model_averaging = cfa_consensus.federated_weights_computing(
                            neighbor, args.N, epoch_count, eps_c, max_lag)
                        model.set_weights(model_averaging)
                        if cfa_consensus.getTrainingStatusFromNeightbor():
                            # model.set_weights(model_averaging)
                            training_signal = True  # stop local learning, just do validation
                        else:
                            grads = cfa_consensus.federated_grads_computing(
                                neighbor, args.N, epoch_count, args.eps_grads,
                                max_lag)
                            optimizer2.apply_gradients(
                                zip(grads, model.trainable_variables))
            else:
                print("Warm up")
                train_start = False

            del model_weights

        #start = time.time()
        # validation tool for device 'device_index'
        if epoch_count > validation_start and frame_count % number_of_batches == 0:
            avg_cost = 0.
            for i in range(number_of_batches_for_validation):
                obs_valid, labels_valid = data_handle.getTestData(
                    batch_size, i)
                # obs_valid, labels_valid = data_handle.getRandomTestData(batch_size)
                data_valid = preprocess_observation(np.squeeze(obs_valid),
                                                    batch_size)
                data_sample = np.array(data_valid)
                label_sample = np.array(labels_valid)
                # Create a mask to calculate loss
                masks = tf.one_hot(label_sample, n_outputs)
                classes = model(data_sample)
                # Apply the masks
                class_v = tf.reduce_sum(tf.multiply(classes, masks), axis=1)
                # Calculate loss
                loss = loss_function(label_sample, class_v)
                avg_cost += loss / number_of_batches_for_validation  # Training loss
            epoch_loss_history.append(avg_cost)
            print("Device {} epoch count {}, validation loss {:.2f}".format(
                device_index, epoch_count, avg_cost))
            # mean loss for last 5 epochs
            running_loss = np.mean(epoch_loss_history[-1:])
        #end = time.time()
        #time_count = (end - start)
        #print(time_count)

        if running_loss < target_loss:  # Condition to consider the task solved
            print(
                "Solved for device {} at epoch {} with average loss {:.2f} !".
                format(device_index, epoch_count, running_loss))
            training_end = True
            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            # model_target.save(checkpointpath2, include_optimizer=True, save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     epoch_loss_history=epoch_loss_history,
                     training_end=training_end,
                     epoch_count=epoch_count,
                     loss=running_loss)
            np.save(outfile_models, model_weights)

            if federated:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "neighbors": args.N,
                    "active_devices": args.Ka_consensus,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            elif parameter_server:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "active_devices": active_devices_per_round,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            else:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }

            if federated:
                sio.savemat(
                    "results/matlab/CFA_device_{}_samples_{}_devices_{}_active_{}_neighbors_{}_batches_{}_size{}_noniid{}_run{}_distribution{}_gradients{}.mat"
                    .format(device_index, samples, devices, args.Ka_consensus,
                            args.N, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution, args.gradients),
                    dict_1)
                sio.savemat(
                    "CFA_device_{}_samples_{}_devices_{}_neighbors_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices, args.N,
                            number_of_batches, batch_size), dict_1)
            elif parameter_server:
                sio.savemat(
                    "results/matlab/FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size, args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
                sio.savemat(
                    "FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size), dict_1)
            else:  # CL
                sio.savemat(
                    "results/matlab/CL_samples_{}_devices_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(samples, devices, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
            break

        if epoch_count > max_epochs:  # stop simulation
            print("Unsolved for device {} at epoch {}!".format(
                device_index, epoch_count))
            training_end = True
            model_weights = np.asarray(model.get_weights())
            model.save(checkpointpath1,
                       include_optimizer=True,
                       save_format='h5')
            # model_target.save(checkpointpath2, include_optimizer=True, save_format='h5')
            np.savez(outfile,
                     frame_count=frame_count,
                     epoch_loss_history=epoch_loss_history,
                     training_end=training_end,
                     epoch_count=epoch_count,
                     loss=running_loss)
            np.save(outfile_models, model_weights)

            if federated:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "neighbors": args.N,
                    "active_devices": args.Ka_consensus,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            elif parameter_server:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "active_devices": active_devices_per_round,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }
            else:
                dict_1 = {
                    "epoch_loss_history": epoch_loss_history,
                    "federated": federated,
                    "parameter_server": parameter_server,
                    "devices": devices,
                    "batches": number_of_batches,
                    "batch_size": batch_size,
                    "samples": samples,
                    "noniid": args.noniid_assignment,
                    "data_distribution": args.random_data_distribution
                }

            if federated:
                sio.savemat(
                    "results/matlab/CFA_device_{}_samples_{}_devices_{}_active_{}_neighbors_{}_batches_{}_size{}_noniid{}_run{}_distribution{}_gradients{}.mat"
                    .format(device_index, samples, devices, args.Ka_consensus,
                            args.N, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution, args.gradients),
                    dict_1)
                sio.savemat(
                    "CFA_device_{}_samples_{}_devices_{}_neighbors_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices, args.N,
                            number_of_batches, batch_size), dict_1)
            elif parameter_server:
                sio.savemat(
                    "results/matlab/FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size, args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
                sio.savemat(
                    "FA_device_{}_samples_{}_devices_{}_active_{}_batches_{}_size{}.mat"
                    .format(device_index, samples, devices,
                            active_devices_per_round, number_of_batches,
                            batch_size), dict_1)
            else:  # CL
                sio.savemat(
                    "results/matlab/CL_samples_{}_devices_{}_batches_{}_size{}_noniid{}_run{}_distribution{}.mat"
                    .format(samples, devices, number_of_batches, batch_size,
                            args.noniid_assignment, args.run,
                            args.random_data_distribution), dict_1)
            break