def call_federated_train(cls):
    """ call the model of each client for federated training """
    with timer("call federated train", logger):
        train_loss = []
        best_epoch = None
        best_loss = float('inf')
        for epoch in range(1, args.server_epoch + 1):
            with timer('train for epoch {}/{}'.format(epoch, args.server_epoch), logger):
                pickled_epoch = Common.get_bytes_by_pickle_object_func(epoch)
                pickled_server_model_params = Common.get_bytes_by_pickle_object_func(cls.server_model_params)
                # fan out one training request per client and wait for all of them
                train_jobs = [gevent.spawn(RequestApi.request,
                                           method="POST",
                                           url=federated_train_url,
                                           params={"server_epoch": pickled_epoch,
                                                   "server_model_params": pickled_server_model_params},
                                           custom_headers={"Content-Type": "multipart/form-data"})
                              for federated_train_url in cls.federated_train_urls]
                gevent.joinall(train_jobs, timeout=args.timeout)

                avg_loss = 0.0
                client_weight_lst = []
                for idx, train_job in enumerate(train_jobs):
                    returned_data = Common.get_object_by_pickle_bytes_func(train_job.value['data'])
                    returned_client_model_params = returned_data["client_model_params"]
                    returned_epo_avg_loss = returned_data["epo_avg_loss"]
                    # incrementally update the average training loss over all clients for this epoch
                    avg_loss += (returned_epo_avg_loss - avg_loss) / (idx + 1)
                    client_weight_lst.append(returned_client_model_params)

                # federated averaging: weight each client's parameters by its training data ratio
                for key in client_weight_lst[-1].keys():
                    client_weight_lst[-1][key] = cls.client_ratio_lst[-1] * client_weight_lst[-1][key]
                    for idx in range(0, len(client_weight_lst) - 1):
                        client_weight_lst[-1][key] += cls.client_ratio_lst[idx] * client_weight_lst[idx][key]
                cls.server_model_params = client_weight_lst[-1]

                logger.info('epoch {:3d}, average loss {:.3f}'.format(epoch, avg_loss))
                train_loss.append(avg_loss)
                # keep the model, loss and epoch with the smallest average training loss across all epochs
                if avg_loss < best_loss:
                    best_loss = avg_loss
                    best_epoch = epoch
                    cls.best_model_params = cls.server_model_params
        logger.info("best train loss: {}".format(best_loss))
        logger.info("best epoch: {}".format(best_epoch))
        with open(args.model_file_path, "wb") as model_file:
            pickle.dump(cls.best_model_params, model_file)
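# A minimal, standalone sketch of the weighted federated-averaging step used in
# call_federated_train above, assuming each client returns a PyTorch-style state
# dict of tensors and that the ratios already sum to 1. The helper name
# fed_avg_state_dicts is illustrative, not part of this codebase.
import torch

def fed_avg_state_dicts(state_dicts, ratios):
    """ return the element-wise weighted average of the given state dicts """
    averaged = {}
    for key in state_dicts[0].keys():
        # sum ratio_i * params_i[key] over all clients
        averaged[key] = sum(ratio * state_dict[key]
                            for ratio, state_dict in zip(ratios, state_dicts))
    return averaged

# e.g. two clients holding 1/4 and 3/4 of the training data (params_a and
# params_b are hypothetical state dicts):
#   fed_avg_state_dicts([params_a, params_b], [0.25, 0.75])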
def detect(cls, detect_model_params: object):
    """ get the test results and save the json file of the federated averaging model on the test set """
    with timer("detect for user id {}".format(args.user_id), logger):
        model = cls.get_model(model_params=detect_model_params)
        model.eval()

        image_list = []  # store image paths
        image_detection_list = []  # store the detections for each image index
        for idx, (image_path_list, images) in enumerate(tqdm.tqdm(cls.federated_detect_loader, desc="detect")):
            images = Variable(images.to(args.device), requires_grad=False)
            with torch.no_grad():
                outputs = model(images)
                outputs = non_max_suppression(outputs, conf_threshold=args.conf_threshold,
                                              nms_threshold=args.nms_threshold)
            # save image paths and detections
            image_list.extend(image_path_list)
            image_detection_list.extend(outputs)

        annotation_list = []
        for idx, (image_path, image_detection) in enumerate(zip(image_list, image_detection_list)):
            logger.info("(%d) image: (%s)" % (idx, image_path))
            image = np.array(Image.open(image_path))
            # find the image id corresponding to the image path
            image_name = image_path.split('/')[-1]
            if image_name in cls.detect_namelist:
                index = cls.detect_namelist.index(image_name)
                image_id = cls.detect_json_data['images'][index]['id']
                # convert the image detections into annotation dicts
                if image_detection is not None:
                    # rescale boxes to the original image size
                    image_detection = rescale_boxes(image_detection, args.image_size, image.shape[:2])
                    for x1, y1, x2, y2, conf, cls_conf, cls_pred in image_detection:
                        annotation_item_dict = {'image_id': image_id,
                                                'category_id': int(cls_pred) + 1,
                                                'bbox': [x1.item(), y1.item(),
                                                         x2.item() - x1.item(), y2.item() - y1.item()],
                                                'score': conf.item()}
                        annotation_list.append(annotation_item_dict)

        annotation_json_str = Common.get_json_by_dict_func(annotation_list)
        with timer("writing to {}".format(args.submission_file_path), logger):
            with open(args.submission_file_path, 'w', encoding='utf-8') as json_file:
                json_file.writelines(annotation_json_str)
        return Common.get_bytes_by_pickle_object_func({"test": True})
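# The annotation dicts built in detect above follow the COCO results format, in
# which 'bbox' is [x, y, width, height] anchored at the top-left corner, while
# non_max_suppression returns boxes as [x1, y1, x2, y2] corner coordinates.
# A minimal sketch of that conversion (xyxy_to_coco_bbox is illustrative):

def xyxy_to_coco_bbox(x1, y1, x2, y2):
    """ convert corner coordinates to a COCO-style [x, y, w, h] box """
    return [x1, y1, x2 - x1, y2 - y1]

# e.g. xyxy_to_coco_bbox(10, 20, 50, 80) -> [10, 20, 40, 60]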
def call_federated_detect(cls):
    """ send the best model to all the clients for detecting after the federated training """
    pickled_best_model_params = Common.get_bytes_by_pickle_object_func(cls.best_model_params)
    with timer("call federated detect", logger):
        detect_jobs = [gevent.spawn(RequestApi.request,
                                    method="POST",
                                    url=federated_detect_url,
                                    params={"best_model_params": pickled_best_model_params},
                                    custom_headers={"Content-Type": "multipart/form-data"})
                       for federated_detect_url in cls.federated_detect_urls]
        gevent.joinall(detect_jobs, timeout=args.timeout)
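# Note that gevent.joinall(jobs, timeout=...) returns once the timeout fires,
# even if some greenlets have not finished; an unfinished job's .value is None,
# so indexing into it would raise a TypeError. A minimal sketch of a defensive
# collect step under that assumption (collect_finished is illustrative, not
# part of this codebase):
import gevent

def collect_finished(jobs):
    """ return the results of the greenlets that finished, skipping the rest """
    results = []
    for job in jobs:
        if job.successful() and job.value is not None:
            results.append(job.value)
        else:
            logger.warning("job {} did not finish before the timeout".format(job))
    return results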
def call_federated_train_size(cls):
    """ get the training data ratio of each client """
    with timer("call federated train size", logger):
        train_jobs = [gevent.spawn(RequestApi.request,
                                   method="GET",
                                   url=federated_train_size_url,
                                   custom_headers={"Content-Type": "application/json"})
                      for federated_train_size_url in cls.federated_train_size_urls]
        gevent.joinall(train_jobs, timeout=args.timeout)
        for train_job in train_jobs:
            federated_train_size = Common.get_dict_by_json_str_func(train_job.value['data'])["federated_train_size"]
            cls.client_ratio_lst.append(federated_train_size)
        logger.info("before normalization: client_ratio_lst: {}".format(cls.client_ratio_lst))
        # normalize the per-client training set sizes into ratios that sum to 1
        client_ratio_sum = sum(cls.client_ratio_lst)
        cls.client_ratio_lst = [ratio / client_ratio_sum for ratio in cls.client_ratio_lst]
        logger.info("after normalization: client_ratio_lst: {}".format(cls.client_ratio_lst))
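# A worked example of the normalization in call_federated_train_size, assuming
# three clients report training set sizes of 100, 300 and 600 samples. These
# ratios are what later weight each client's parameters during aggregation.
sizes = [100, 300, 600]
ratios = [size / sum(sizes) for size in sizes]
assert ratios == [0.1, 0.3, 0.6]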
def test(cls, test_model_params=None, mode=None):
    """ get the valid / test results of the federated averaging model on the valid / test set """
    with timer("{} for user id {}".format("evaluate" if (mode == "valid") else "test", args.user_id), logger):
        model = cls.get_model(model_params=test_model_params)
        model.eval()

        label_list = []
        metric_list = []
        data_loader = cls.federated_valid_loader if (mode == "valid") else cls.federated_test_loader
        for idx, (_, images, labels) in enumerate(tqdm.tqdm(data_loader, desc=mode)):
            images = Variable(images.to(args.device), requires_grad=False)
            # extract class labels
            label_list += labels[:, 1].tolist()
            # rescale targets from normalized xywh to pixel xyxy
            labels[:, 2:] = xywh2xyxy(labels[:, 2:])
            labels[:, 2:] *= args.image_size
            with torch.no_grad():
                outputs = model(images)
                outputs = non_max_suppression(outputs, conf_threshold=args.conf_threshold,
                                              nms_threshold=args.nms_threshold)
            metric_list += get_batch_statistics(outputs, labels, iou_threshold=args.iou_threshold)

        if len(metric_list) != 0:
            true_positives, pred_scores, pred_labels = [np.concatenate(x, 0) for x in list(zip(*metric_list))]
            precision, recall, AP, f1, ap_class = ap_per_class(true_positives, pred_scores, pred_labels, label_list)
        else:
            logger.info("the metric list is empty!")

        if args.verbose and (len(metric_list) != 0):
            # log per-class APs and the overall mAP
            ap_table = [["Index", "Class name", "Precision", "Recall", "F1", "AP"]]
            for i, c in enumerate(ap_class):
                ap_table += [[c, cls.class_names[c], "%.5f" % precision[i], "%.5f" % recall[i],
                              "%.5f" % f1[i], "%.5f" % AP[i]]]
            ap_table += [[len(ap_class), "Total", "%.5f" % precision.mean(), "%.5f" % recall.mean(),
                          "%.5f" % f1.mean(), "%.5f" % AP.mean()]]
            logger.info(AsciiTable(ap_table).table)
            logger.info(f"---- mAP {AP.mean()}")
        return Common.get_bytes_by_pickle_object_func({"test": True})
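# The targets consumed in test above store each box as normalized
# (center_x, center_y, width, height); xywh2xyxy converts them to corner
# coordinates before they are scaled to pixels. A minimal sketch of that
# conversion for a single box (xywh_to_xyxy_single is illustrative):

def xywh_to_xyxy_single(cx, cy, w, h):
    """ convert a center-format box to corner format """
    return [cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]

# e.g. a box centered at (0.5, 0.5) with width 0.2 and height 0.4 becomes
# [0.4, 0.3, 0.6, 0.7] and is then multiplied by args.image_size.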
def train(cls, server_model_params=None, epoch=None):
    """ the current client does federated training on its local training dataset,
    starting from the latest server model parameters for the given server epoch """
    logger.info("start user_id {} training for epoch {}!".format(args.user_id, epoch))
    model = cls.get_model(model_params=server_model_params)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    epo_avg_loss = 0.0
    for epo in range(1, args.client_epoch + 1):
        with timer('federated train for epoch {}/{}, idx {}/{}, epo {}/{}'.format(
                epoch, args.server_epoch, args.user_id, args.user_num, epo, args.client_epoch), logger):
            avg_loss = 0.0
            for batch_idx, (_, images, labels) in enumerate(cls.federated_train_loader):
                done_batches = len(cls.federated_train_loader) * epoch + batch_idx
                images = Variable(images.to(args.device))
                labels = Variable(labels.to(args.device), requires_grad=False)

                loss, outputs = model(images, labels)
                loss.backward()
                # accumulate gradients and only step the optimizer every
                # `gradient_accumulations` batches
                if done_batches % args.gradient_accumulations == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                if args.verbose and batch_idx % args.log_interval == 0:
                    log_str = '\nfederated train for [{}/{} ({:.0f}%)]\tloss: {:.6f}\n'.format(
                        batch_idx * len(images),
                        len(cls.federated_train_loader.dataset),
                        100. * batch_idx / len(cls.federated_train_loader),
                        loss.item())
                    metric_table = [["Metrics",
                                     *[f"YOLO Layer {i}" for i in range(len(model.yolo_layers))]]]
                    # log metrics at each YOLO layer
                    formats = {m: "%.6f" for m in cls.metrics}
                    formats["grid_size"] = "%2d"
                    formats["cls_acc"] = "%.2f%%"
                    for i, metric in enumerate(cls.metrics):
                        row_metrics = [formats[metric] % yolo.metrics.get(metric, 0)
                                       for yolo in model.yolo_layers]
                        metric_table += [[metric, *row_metrics]]
                    log_str += AsciiTable(metric_table).table
                    logger.info(log_str)

                model.seen += images.size(0)
                # incrementally update the client's average training loss over batches
                avg_loss += (loss.item() - avg_loss) / (batch_idx + 1)

            logger.info('epo {:3d}, average loss {:.3f}'.format(epo, avg_loss))
            if epo % args.evaluation_interval == 0:
                # evaluate the model on the validation set
                client_model_params = cls.get_model_params(model=model)
                cls().test(test_model_params=client_model_params, mode="valid")
            if epo % args.checkpoint_interval == 0:
                with open(args.model_file_path, "wb") as model_file:
                    pickle.dump(model.state_dict(), model_file)
            # incrementally update the client's average training loss over epochs
            epo_avg_loss += (avg_loss - epo_avg_loss) / epo

    logger.info('user id {:3d}, average loss {:.3f}'.format(args.user_id, epo_avg_loss))
    model = model.to(torch.device('cpu'))
    client_model_params = cls.get_model_params(model=model)
    cls().test(test_model_params=client_model_params, mode="test")
    return Common.get_bytes_by_pickle_object_func({
        "client_model_params": client_model_params,
        "epo_avg_loss": epo_avg_loss
    })
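# A minimal sketch of the gradient-accumulation pattern used in train above:
# backward() is called on every batch so gradients add up, and the optimizer
# only steps (and clears the gradients) every `accumulation_steps` batches,
# which approximates a batch size `accumulation_steps` times larger. The
# helper name train_with_accumulation is illustrative, and the model is
# assumed to return a scalar loss.
import torch

def train_with_accumulation(model, loader, optimizer, accumulation_steps):
    """ illustrative accumulation loop over (images, labels) batches """
    optimizer.zero_grad()
    for batch_idx, (images, labels) in enumerate(loader, start=1):
        loss = model(images, labels)
        loss.backward()  # gradients accumulate across batches
        if batch_idx % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()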