def call_federated_train(cls):
        """Call the model of each client for federated training.

        Runs ``args.server_epoch`` rounds. In each round, the epoch number and
        ``cls.server_model_params`` are pickled and POSTed concurrently (gevent)
        to every URL in ``cls.federated_train_urls``. The call then waits at most
        ``args.timeout`` seconds for the clients. Each response's ``data`` is
        unpickled into a dict holding ``client_model_params`` and
        ``epo_avg_loss``.

        The client parameters are combined key by key as
        ``sum_i cls.client_ratio_lst[i] * params_i``, with ratios matched to
        clients by position. The result becomes the new
        ``cls.server_model_params``.

        The parameters from the round with the lowest mean client loss are kept
        in ``cls.best_model_params`` and pickled to ``args.model_file_path``.

        Raises:
            RuntimeError: a client request produced no value (it timed out or
                failed), or no client returned parameters in a round.
        """
        with timer("call federated train", logger):
            best_epoch = None
            best_loss = float('inf')

            for epoch in range(1, args.server_epoch + 1):
                with timer('train for epoch {}/{}'.format(epoch, args.server_epoch), logger):

                    pickled_epoch = Common.get_bytes_by_pickle_object_func(epoch)
                    pickled_server_model_params = Common.get_bytes_by_pickle_object_func(cls.server_model_params)

                    train_jobs = [gevent.spawn(RequestApi.request, method="POST", url=federated_train_url,
                                               params={"server_epoch": pickled_epoch,
                                                       "server_model_params": pickled_server_model_params},
                                               custom_headers={"Content-Type": "multipart/form-data"}) for
                                  federated_train_url in cls.federated_train_urls]
                    gevent.joinall(train_jobs, timeout=args.timeout)

                    avg_loss = 0.0
                    client_weight_lst = []

                    for idx, (train_job, federated_train_url) in enumerate(
                            zip(train_jobs, cls.federated_train_urls)):
                        # joinall() does not raise on timeout or on a failed request; the
                        # greenlet's value is then None, so name the offending client
                        if train_job.value is None:
                            raise RuntimeError("no federated train response from {} for epoch {}".format(
                                federated_train_url, epoch))

                        # unpickle the response once and read both fields from it
                        client_result = Common.get_object_by_pickle_bytes_func(train_job.value['data'])

                        # running mean of the clients' average training loss for this epoch
                        avg_loss += (client_result["epo_avg_loss"] - avg_loss) / (idx + 1)

                        client_weight_lst.append(client_result["client_model_params"])

                    if not client_weight_lst:
                        raise RuntimeError("no client returned model parameters for epoch {}".format(epoch))

                    # weighted federated averaging, accumulated in place into the last
                    # client's freshly unpickled parameter dict
                    aggregated_params = client_weight_lst[-1]
                    for key in aggregated_params.keys():
                        aggregated_params[key] = cls.client_ratio_lst[-1] * aggregated_params[key]
                        for idx in range(len(client_weight_lst) - 1):
                            aggregated_params[key] += cls.client_ratio_lst[idx] * client_weight_lst[idx][key]

                    cls.server_model_params = aggregated_params

                    logger.info('epoch {:3d}, average loss {:.3f}'.format(epoch, avg_loss))

                    # keep the parameters of the epoch with the smallest average training loss
                    if avg_loss < best_loss:
                        best_loss = avg_loss
                        best_epoch = epoch
                        cls.best_model_params = cls.server_model_params

            logger.info("best train loss: {}".format(best_loss))
            logger.info("best epoch: {}".format(best_epoch))
            # close the file deterministically instead of leaking the handle
            with open(args.model_file_path, "wb") as model_file:
                pickle.dump(cls.best_model_params, model_file)
示例#2
0
    def detect(cls, detect_model_params: object):
        """Run detection on the federated detect set and write the results as JSON.

        The model built from ``detect_model_params`` is run over
        ``cls.federated_detect_loader``, followed by non-max suppression. Every
        kept box is rescaled to its source image's size and emitted as
        ``{'image_id', 'category_id', 'bbox': [x1, y1, w, h], 'score'}``.

        ``image_id`` is looked up in ``cls.detect_json_data['images']`` at the
        position of the image's file name in ``cls.detect_namelist``.
        ``category_id`` is the predicted class index plus one. The list is
        written to ``args.submission_file_path``.

        Images whose file name is not in ``cls.detect_namelist`` are logged and
        skipped, because their detections cannot be attributed to an image id.

        Returns:
            Pickled bytes of ``{"test": True}``.
        """
        with timer("detect for user id {}".format(args.user_id), logger):
            model = cls.get_model(model_params=detect_model_params)
            model.eval()

            image_list = []  # image paths, in loader order
            image_detection_list = []  # NMS output for the image at the same index

            for image_path_list, images in tqdm.tqdm(cls.federated_detect_loader, desc="detect"):
                images = Variable(images.to(args.device), requires_grad=False)

                with torch.no_grad():
                    outputs = model(images)
                    outputs = non_max_suppression(outputs, conf_threshold=args.conf_threshold, nms_threshold=args.nms_threshold)

                # save images and detections
                image_list.extend(image_path_list)
                image_detection_list.extend(outputs)

            # file name -> position in detect_namelist, built once instead of list.index()
            # per image; setdefault keeps the first occurrence, as list.index() does
            name_to_index = {}
            for name_index, name in enumerate(cls.detect_namelist):
                name_to_index.setdefault(name, name_index)

            annotation_list = []
            for idx, (image_path, image_detection) in enumerate(zip(image_list, image_detection_list)):
                logger.info("(%d) image: (%s)" % (idx, image_path))

                # find the image id corresponding to the image path; never reuse the
                # previous image's id for an unknown file name
                image_name = image_path.split('/')[-1]
                name_index = name_to_index.get(image_name)
                if name_index is None:
                    logger.info("skip image (%s): file name not in the detect name list" % image_path)
                    continue
                image_id = cls.detect_json_data['images'][name_index]['id']

                if image_detection is None:
                    continue

                # only the original (height, width) is needed to rescale the boxes, so
                # read it from the header and close the file instead of decoding pixels
                with Image.open(image_path) as pil_image:
                    width, height = pil_image.size
                image_detection = rescale_boxes(image_detection, args.image_size, (height, width))
                for x1, y1, x2, y2, conf, cls_conf, cls_pred in image_detection:
                    annotation_item_dict = {'image_id': image_id, 'category_id': int(cls_pred) + 1,
                                            'bbox': [x1.item(), y1.item(), x2.item() - x1.item(),
                                                     y2.item() - y1.item()], 'score': conf.item()}
                    annotation_list.append(annotation_item_dict)

            annotation_json_str = Common.get_json_by_dict_func(annotation_list)

            with timer("writing to {}".format(args.submission_file_path), logger):
                with open(args.submission_file_path, 'w', encoding='utf-8') as json_file:
                    json_file.write(annotation_json_str)

            return Common.get_bytes_by_pickle_object_func({"test": True})
 def call_federated_detect(cls):
     """Send the best model to all the clients for detecting after federated training.

     ``cls.best_model_params`` is pickled once and POSTed concurrently (gevent)
     to every URL in ``cls.federated_detect_urls``. The call then waits at most
     ``args.timeout`` seconds. Response contents are not inspected. Clients
     whose request did not finish successfully in time are logged, because
     ``gevent.joinall`` itself reports neither timeouts nor request errors.
     """
     pickled_best_model_params = Common.get_bytes_by_pickle_object_func(cls.best_model_params)
     with timer("call federated detect", logger):
         detect_jobs = [gevent.spawn(RequestApi.request, method="POST", url=federated_detect_url,
                                     params={"best_model_params": pickled_best_model_params},
                                     custom_headers={"Content-Type": "multipart/form-data"}) for federated_detect_url in
                      cls.federated_detect_urls]
         gevent.joinall(detect_jobs, timeout=args.timeout)

         # successful() is false both for greenlets still running after the timeout
         # and for those that raised
         for detect_job, federated_detect_url in zip(detect_jobs, cls.federated_detect_urls):
             if not detect_job.successful():
                 logger.info("federated detect did not finish successfully for {}".format(federated_detect_url))
    def call_federated_train_size(cls):
        """Get the training data ratio of each client.

        Sends a GET concurrently (gevent) to every URL in
        ``cls.federated_train_size_urls`` and waits at most ``args.timeout``
        seconds. Each response's ``data`` is parsed as JSON, and its
        ``federated_train_size`` is collected in URL order.

        The sizes are divided by their total. ``cls.client_ratio_lst`` then
        holds one weight per client, in the same order, normalized to sum to 1.
        The list is rebuilt on every call, so a repeated call never mixes new
        sizes with previously normalized ratios.

        Raises:
            RuntimeError: a client request produced no value (timed out or failed).
            ValueError: the reported sizes sum to zero.
        """
        with timer("call federated train size", logger):
            train_jobs = [gevent.spawn(RequestApi.request, method="GET", url=federated_train_size_url,
                                       custom_headers={"Content-Type": "application/json"}) for federated_train_size_url
                          in cls.federated_train_size_urls]
            gevent.joinall(train_jobs, timeout=args.timeout)

            client_sizes = []
            for train_job, federated_train_size_url in zip(train_jobs, cls.federated_train_size_urls):
                # joinall() does not raise on timeout; the greenlet's value is then None
                if train_job.value is None:
                    raise RuntimeError("no federated train size response from {}".format(federated_train_size_url))
                client_sizes.append(
                    Common.get_dict_by_json_str_func(train_job.value['data'])["federated_train_size"])
            cls.client_ratio_lst = client_sizes

            logger.info("before normalization: client_ratio_lst: {}".format(cls.client_ratio_lst))
            client_ratio_sum = sum(cls.client_ratio_lst)
            if cls.client_ratio_lst and client_ratio_sum == 0:
                raise ValueError("federated train sizes sum to zero: {}".format(cls.client_ratio_lst))
            cls.client_ratio_lst = [ratio / client_ratio_sum for ratio in cls.client_ratio_lst]
            logger.info("after normalization: client_ratio_lst: {}".format(cls.client_ratio_lst))
示例#5
0
    def test(cls, test_model_params=None, mode=None):
        """Evaluate the model built from ``test_model_params`` on the valid or test set.

        When ``mode == "valid"``, ``cls.federated_valid_loader`` is used and the
        timer label is "evaluate". Any other value uses
        ``cls.federated_test_loader`` with the label "test".

        Predictions go through non-max suppression and are matched to the
        targets by ``get_batch_statistics``. ``ap_per_class`` then computes
        precision, recall, AP and F1. The per-class table and the mAP are
        logged only when ``args.verbose`` is set.

        Returns:
            Pickled bytes of ``{"test": True}``; the metrics are only logged.
        """
        use_valid_set = mode == "valid"
        phase_name = "evaluate" if use_valid_set else "test"
        with timer("{} for user id {}".format(phase_name, args.user_id), logger):
            model = cls.get_model(model_params=test_model_params)
            model.eval()

            data_loader = cls.federated_valid_loader if use_valid_set else cls.federated_test_loader

            target_classes = []  # column 1 of every target row
            batch_statistics = []  # concatenated results of get_batch_statistics
            for _, batch_images, batch_targets in tqdm.tqdm(data_loader, desc=mode):
                batch_images = Variable(batch_images.to(args.device), requires_grad=False)

                target_classes.extend(batch_targets[:, 1].tolist())
                # target boxes: xywh -> xyxy, then scaled by args.image_size
                # (presumably from normalized coordinates; confirm against the dataset)
                batch_targets[:, 2:] = xywh2xyxy(batch_targets[:, 2:])
                batch_targets[:, 2:] *= args.image_size

                with torch.no_grad():
                    raw_predictions = model(batch_images)
                    detections = non_max_suppression(raw_predictions, conf_threshold=args.conf_threshold,
                                                     nms_threshold=args.nms_threshold)

                batch_statistics.extend(
                    get_batch_statistics(detections, batch_targets, iou_threshold=args.iou_threshold))

            if not batch_statistics:
                logger.info("the metric list is empty!")
            else:
                true_positives, pred_scores, pred_labels = [np.concatenate(column, 0)
                                                            for column in zip(*batch_statistics)]
                precision, recall, AP, f1, ap_class = ap_per_class(true_positives, pred_scores, pred_labels,
                                                                   target_classes)

                if args.verbose:
                    # one row per class, then a "Total" row with the means
                    ap_table = [["Index", "Class name", "Precision", "Recall", "F1", "AP"]]
                    for row_idx, class_id in enumerate(ap_class):
                        ap_table.append([class_id, cls.class_names[class_id], "%.5f" % precision[row_idx],
                                         "%.5f" % recall[row_idx], "%.5f" % f1[row_idx], "%.5f" % AP[row_idx]])
                    ap_table.append([len(ap_class), "Total", "%.5f" % precision.mean(), "%.5f" % recall.mean(),
                                     "%.5f" % f1.mean(), "%.5f" % AP.mean()])
                    logger.info(AsciiTable(ap_table).table)
                    logger.info(f"---- mAP {AP.mean()}")

            return Common.get_bytes_by_pickle_object_func({"test": True})
示例#6
0
    def train(cls, server_model_params=None, epoch=None):
        """Train the current client on its local training set for ``args.client_epoch`` epochs.

        The model is built from ``server_model_params`` and optimized with Adam
        on ``cls.federated_train_loader``. Gradients accumulate over
        ``args.gradient_accumulations`` consecutive batches between optimizer
        steps, counted across all local epochs of this call.

        Every ``args.evaluation_interval`` local epochs, the model is evaluated
        on the validation set. Every ``args.checkpoint_interval`` local epochs,
        its ``state_dict`` is pickled to ``args.model_file_path``. After
        training, the model is moved to the CPU and evaluated on the test set.

        Args:
            server_model_params: parameters to start from (passed to ``cls.get_model``).
            epoch: current server epoch; only used in log and timer messages.

        Returns:
            Pickled bytes of ``{"client_model_params": ..., "epo_avg_loss": ...}``.
            ``epo_avg_loss`` is the mean, over local epochs, of each epoch's mean
            batch loss.
        """
        logger.info("start user_id {} training for epoch {}!".format(
            args.user_id, epoch))
        model = cls.get_model(model_params=server_model_params)
        model.train()

        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

        done_batches = 0  # batches processed in this call, across all local epochs
        epo_avg_loss = 0.0
        for epo in range(1, args.client_epoch + 1):
            with timer(
                    'federated train for epoch {}/{}, idx {}/{}, epo {}/{}'.
                    format(epoch, args.server_epoch, args.user_id,
                           args.user_num, epo, args.client_epoch), logger):
                avg_loss = 0.0
                for batch_idx, (_, images, labels) in enumerate(
                        cls.federated_train_loader):
                    images = Variable(images.to(args.device))
                    labels = Variable(labels.to(args.device),
                                      requires_grad=False)

                    loss, _ = model(images, labels)
                    loss.backward()

                    # gradients accumulate across batches; step once every
                    # `gradient_accumulations` batches (every batch when it is 1)
                    done_batches += 1
                    if done_batches % args.gradient_accumulations == 0:
                        optimizer.step()
                        optimizer.zero_grad()

                    if args.verbose and batch_idx % args.log_interval == 0:
                        log_str = '\nfederated train for [{}/{} ({:.0f}%)]\tloss: {:.6f}\n'.format(
                            batch_idx * len(images),
                            len(cls.federated_train_loader.dataset),
                            100. * batch_idx / len(cls.federated_train_loader),
                            loss.item())
                        metric_table = [[
                            "Metrics", *[
                                f"YOLO Layer {i}"
                                for i in range(len(model.yolo_layers))
                            ]
                        ]]

                        # printf-style format per metric, built once per log line
                        formats = {m: "%.6f" for m in cls.metrics}
                        formats["grid_size"] = "%2d"
                        formats["cls_acc"] = "%.2f%%"

                        # log metrics at each YOLO layer
                        for metric in cls.metrics:
                            row_metrics = [
                                formats[metric] % yolo.metrics.get(metric, 0)
                                for yolo in model.yolo_layers
                            ]
                            metric_table.append([metric, *row_metrics])

                        log_str += AsciiTable(metric_table).table

                        logger.info(log_str)

                    # count every image seen, not only those in logged batches
                    model.seen += images.size(0)

                    # update the average training loss for the client based on the batch
                    avg_loss += (loss.item() - avg_loss) / (batch_idx + 1)
                logger.info('epo {:3d}, average loss {:.3f}'.format(
                    epo, avg_loss))

                if epo % args.evaluation_interval == 0:
                    # evaluate the model on the validation set
                    client_model_params = cls.get_model_params(model=model)
                    cls().test(test_model_params=client_model_params,
                               mode="valid")

                if epo % args.checkpoint_interval == 0:
                    with open(args.model_file_path, "wb") as checkpoint_file:
                        pickle.dump(model.state_dict(), checkpoint_file)

                # update the average training loss for the client
                epo_avg_loss += (avg_loss - epo_avg_loss) / epo
        logger.info('user id {:3d}, average loss {:.3f}'.format(
            args.user_id, epo_avg_loss))

        model = model.to(torch.device('cpu'))
        client_model_params = cls.get_model_params(model=model)
        cls().test(test_model_params=client_model_params, mode="test")

        return Common.get_bytes_by_pickle_object_func({
            "client_model_params":
            client_model_params,
            "epo_avg_loss":
            epo_avg_loss
        })