# Branch tags, in the fixed order used for optimizers, schedulers and
# checkpoint files.
_BRANCHES = ('dense', 'res', 'fusion')

# args.step -> (step name handed to train()/valid(), word used in the start
# banner, label used in the finish message, branches that are trainable,
# scheduler-stepped and checkpointed during that stage).
_STAGES = {
    1: ('step1', 'densenet', 'Dense', ('dense', 'fusion')),
    2: ('step2', 'local', 'Resnet', ('res', 'fusion')),
    3: ('step3', 'fusion', 'Fusion', ('dense', 'res', 'fusion')),
}


def _dataset_classes(dataset, no_fiding):
    """Return the label names for ``dataset``, or ``None`` if it is unknown."""
    if dataset == 'ChestXray-NIHCC':
        classes = [
            'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
            'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
            'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
        ]
        if no_fiding:
            # Spelling kept exactly as in the original label list.
            classes.append('No Fiding')
        return classes
    if dataset == 'CheXpert-v1.0-small':
        return [
            'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
            'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
            'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
            'Pleural Other', 'Fracture', 'Support Devices'
        ]
    return None


def _build_transforms():
    """Return the 'train'/'valid' pipelines; they differ only by the flip."""

    def resize_crop_normalize():
        return [
            transforms.Resize(556),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]

    return {
        'train': transforms.Compose([transforms.RandomHorizontalFlip()] +
                                    resize_crop_normalize()),
        'valid': transforms.Compose(resize_crop_normalize()),
    }


def _print_model_size(model):
    """Print the parameter count of ``model`` in millions."""
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))


def _build_classifier(config_path, branch):
    """Load the JSON config at ``config_path`` and build a Classifier from it."""
    with open(config_path) as f:
        cfg = edict(json.load(f))
    print('Initializing {} branch'.format(branch))
    model = Classifier(cfg)
    _print_model_size(model)
    return model, cfg


def _resume_branch(model, checkpoint_path, branch):
    """Load the weights stored under 'state_dict' in a checkpoint into ``model``."""
    # The models still live on the CPU when this runs, so always map the
    # checkpoint to the CPU; otherwise a GPU-saved checkpoint cannot be loaded
    # on a CPU-only run.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    print("Resuming {} from epoch {}".format(branch, checkpoint['epoch'] + 1))


def _run_training_stage(args, stage, models, optimizers, schedulers,
                        train_loader, valid_loader, criterion, classes, cfg,
                        data_transforms, use_gpu):
    """Train one stage: fit, periodically validate, step schedulers, checkpoint.

    ``stage`` is an entry of ``_STAGES``. ``models``, ``optimizers`` and
    ``schedulers`` are dicts keyed by the tags in ``_BRANCHES``;
    ``schedulers`` is only read when ``args.stepsize > 0``.
    """
    step_name, banner, label, active = stage
    start_time = time.time()
    train_time = 0
    best_epoch = 0
    best_loss = np.inf
    print("==> Start training of {} branch".format(banner))

    # Only the branches of this stage receive gradients.
    for tag in _BRANCHES:
        trainable = tag in active
        for p in models[tag].parameters():
            p.requires_grad = trainable

    for epoch in range(args.start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(step_name, models['dense'], models['res'], models['fusion'],
              train_loader, optimizers['dense'], optimizers['res'],
              optimizers['fusion'], criterion, args.print_freq, epoch,
              args.max_epoch, cfg, data_transforms['train'])
        train_time += round(time.time() - start_train_time)

        # Validate every eval_step epochs and always after the last epoch.
        # The short-circuit keeps eval_step == 0 from reaching the modulo.
        periodic = args.eval_step > 0 and (epoch + 1) % args.eval_step == 0
        if not (periodic or (epoch + 1) == args.max_epoch):
            continue

        print("==> Validation")
        loss_val = valid(step_name, models['dense'], models['res'],
                         models['fusion'], valid_loader, criterion,
                         args.print_freq, classes, cfg,
                         data_transforms['valid'])

        # NOTE(review): schedulers advance only on validation epochs because
        # they may need loss_val; confirm this cadence is intended.
        if args.stepsize > 0:
            for tag in active:
                if args.scheduler == 'ReduceLROnPlateau':
                    schedulers[tag].step(loss_val)
                else:
                    schedulers[tag].step()

        is_best = loss_val < best_loss
        if is_best:
            best_loss = loss_val
            best_epoch = epoch + 1

        filename = 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'
        for tag in active:
            # Under DataParallel the real module sits behind `.module`.
            net = models[tag].module if use_gpu else models[tag]
            save_checkpoint(
                {
                    'state_dict': net.state_dict(),
                    'loss': best_loss,
                    'epoch': epoch,
                }, is_best, args.save_dir, filename, tag)

    print("==> Best Validation Loss {:.4%}, achieved at epoch {}".format(
        best_loss, best_epoch))
    elapsed = str(datetime.timedelta(seconds=round(time.time() - start_time)))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "{} branch finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}."
        .format(label, elapsed, train_time))


def main():
    """Train or evaluate the DenseNet / ResNet / fusion chest X-ray model.

    Reads every setting from the module-level ``args`` namespace. ``args.step``
    selects the stage (1: DenseNet + fusion, 2: ResNet + fusion, 3: all three
    branches). With ``args.evaluate`` set, a single validation pass is run for
    that stage instead of training. Standard output is redirected to a Logger
    file inside ``args.save_dir``.
    """
    classes = _dataset_classes(args.dataset, args.no_fiding)
    if classes is None:
        print('--dataset incorrect')
        return

    torch.manual_seed(args.seed)
    use_gpu = torch.cuda.is_available() and not args.use_cpu

    log_name = 'log_test.txt' if args.evaluate else 'log_train.txt'
    sys.stdout = Logger(osp.join(args.save_dir, log_name))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")
    pin_memory = use_gpu

    print("Initializing dataset: {}".format(args.dataset))
    data_transforms = _build_transforms()
    dataset_train = DatasetGenerator(path_base=args.base_dir,
                                     dataset_file='train',
                                     transform=data_transforms['train'],
                                     dataset_=args.dataset,
                                     no_fiding=args.no_fiding)
    dataset_valid = DatasetGenerator(path_base=args.base_dir,
                                     dataset_file='valid',
                                     transform=data_transforms['valid'],
                                     dataset_=args.dataset,
                                     no_fiding=args.no_fiding)
    train_loader = DataLoader(dataset=dataset_train,
                              batch_size=args.train_batch,
                              shuffle=args.train_shuffle,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    valid_loader = DataLoader(dataset=dataset_valid,
                              batch_size=args.valid_batch,
                              shuffle=args.valid_shuffle,
                              num_workers=args.workers,
                              pin_memory=pin_memory)

    model_dense, _ = _build_classifier(args.infos_densenet, 'densenet')
    # NOTE(review): train()/valid() receive the ResNet config (the last one
    # loaded), exactly as before - confirm that this is intended.
    model_res, cfg = _build_classifier(args.infos_resnet, 'resnet')
    print('Initializing fusion branch')
    model_fusion = Fusion(input_size=7424, output_size=len(classes))
    _print_model_size(model_fusion)
    models = {'dense': model_dense, 'res': model_res, 'fusion': model_fusion}

    print("Initializing optimizers")
    optimizers = {
        tag: init_optim(args.optim, models[tag].parameters(),
                        args.learning_rate, args.weight_decay, args.momentum)
        for tag in _BRANCHES
    }
    criterion = nn.BCELoss()

    print("Initializing scheduler: {}".format(args.scheduler))
    schedulers = {}
    if args.stepsize > 0:
        schedulers = {
            tag: init_scheduler(args.scheduler, optimizers[tag],
                                args.stepsize, args.gamma)
            for tag in _BRANCHES
        }

    if args.resume_densenet:
        _resume_branch(model_dense, args.resume_densenet, 'densenet')
    if args.resume_resnet:
        _resume_branch(model_res, args.resume_resnet, 'resnet')
    if args.resume_fusion:
        _resume_branch(model_fusion, args.resume_fusion, 'fusion')

    if use_gpu:
        models = {
            tag: nn.DataParallel(models[tag]).cuda() for tag in _BRANCHES
        }

    stage = _STAGES.get(args.step)

    if args.evaluate:
        print("Evaluate only")
        if stage is None:
            print('args.step not found')
        else:
            valid(stage[0], models['dense'], models['res'], models['fusion'],
                  valid_loader, criterion, args.print_freq, classes, cfg,
                  data_transforms['valid'])
        return

    if stage is None:
        print('args.step not found')
        return

    _run_training_stage(args, stage, models, optimizers, schedulers,
                        train_loader, valid_loader, criterion, classes, cfg,
                        data_transforms, use_gpu)
def run_fl(args):
    """Run federated training of a Classifier over one client per train CSV.

    Every communication round, each client starts from the current global
    weights, trains for ``cfg.local_epoch`` epochs, and hands its weights back.
    The client weights are then aggregated according to ``cfg.fl_technique``
    ('FedAvg', 'WFedAvg' or 'FedProx'), the aggregated model is evaluated on
    the dev set, and checkpoints are written under ``args.save_path``.

    Args:
        args: parsed command-line namespace. Fields read here: ``cfg_path``,
            ``verbose``, ``save_path``, ``logtofile``, ``resume``,
            ``device_ids``, ``pre_train`` and ``num_workers``.

    Raises:
        Exception: if fewer CUDA devices are available than ids were
            requested, or if the ``du`` / ``cp`` shell command exits non-zero.
    """
    import shlex  # local import: only needed to quote paths for the shell

    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
    if args.verbose is True:
        print(json.dumps(cfg, indent=4))

    # Also creates missing parent directories; no error if it already exists.
    os.makedirs(args.save_path, exist_ok=True)
    if args.logtofile is True:
        logging.basicConfig(filename=args.save_path + '/log.txt',
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    # initialise global model
    model = Classifier(cfg).to(device).train()
    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))

    if args.pre_train is not None and os.path.exists(args.pre_train):
        ckpt = torch.load(args.pre_train, map_location=device)
        model.load_state_dict(ckpt)

    # Snapshot the source tree next to the results. Paths are shell-quoted so
    # directories containing spaces or shell metacharacters still work.
    src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    dst_folder = os.path.join(args.save_path, 'classification')
    # NOTE: the exit status of a shell pipeline is that of its last command
    # (`cut`), so a failing `du` is not detected by this check.
    rc, _ = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1' %
                                       shlex.quote(src_folder))
    if rc != 0:
        raise Exception('Copy folder error : {}'.format(rc))
    print('Successfully determined size of directory')
    rc, err_msg = subprocess.getstatusoutput(
        'cp -R %s %s' % (shlex.quote(src_folder), shlex.quote(dst_folder)))
    if rc != 0:
        raise Exception('copy folder error : {}'.format(err_msg))
    print('Successfully copied folder')

    # One client per training CSV, named 'A', 'B', ... (zip stops at the
    # shorter input, so at most 26 clients are created).
    clients = {}
    for client, train_csv in zip(string.ascii_uppercase, cfg.train_csv):
        copyfile(train_csv,
                 os.path.join(args.save_path, f'train_{client}.csv'))
        clients[client] = {
            'dataloader_train':
            DataLoader(ImageDataset(train_csv, cfg, mode='train'),
                       batch_size=cfg.train_batch_size,
                       num_workers=args.num_workers,
                       drop_last=True,
                       shuffle=True),
            'bytes_uploaded': 0.0,
            'epoch': 0,
        }

    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'dev.csv'))
    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header

    w_global = model.state_dict()

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    comm_rounds = cfg.epoch
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }
    num_tasks = len(cfg.num_classes)

    # Communication rounds loop
    for cr in range(comm_rounds):
        logging.info('{}, Start communication round {} of FL - {} ...'.format(
            time.strftime("%Y-%m-%d %H:%M:%S"), cr + 1, cfg.fl_technique))

        w_locals = []

        for client in clients:
            logging.info(
                '{}, Start local training process for client {}, communication round: {} ...'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1))

            # Every client starts from the current global weights.
            model = Classifier(cfg).to(device).train()
            model.load_state_dict(w_global)

            if cfg.fl_technique == "FedProx":
                global_weight_collector = get_global_weights(model, device)
            else:
                global_weight_collector = None

            optimizer = get_optimizer(model.parameters(), cfg)

            # local training loops
            for epoch in range(cfg.local_epoch):
                lr = lr_schedule(cfg.lr, cfg.lr_factor, epoch, cfg.lr_epochs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                summary_train, best_dict = train_epoch_fl(
                    summary_train, summary_dev, cfg, args, model,
                    clients[client]['dataloader_train'], dataloader_dev,
                    optimizer, summary_writer, best_dict, dev_header, epoch,
                    global_weight_collector)

                # NOTE(review): confirm this counter is meant to advance once
                # per local epoch (rather than once per client).
                summary_train['step'] += 1

            local_state = model.state_dict()
            # Payload size = bytes held by every tensor in the state dict.
            # sys.getsizeof() would only measure the dict container itself.
            bytes_to_upload = sum(tensor.numel() * tensor.element_size()
                                  for tensor in local_state.values())
            clients[client]['bytes_uploaded'] += bytes_to_upload
            logging.info(
                '{}, Completed local rounds for client {} in communication round {}. '
                'Uploading {} bytes to server, {} bytes in total sent from client'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1,
                        bytes_to_upload, clients[client]['bytes_uploaded']))

            w_locals.append(local_state)

        if cfg.fl_technique == "FedAvg":
            w_global = fed_avg(w_locals)
        elif cfg.fl_technique in ('WFedAvg', 'FedProx'):
            # FedProx aggregates with the weighted average as well.
            w_global = weighted_fed_avg(w_locals, cfg.train_proportions)
        else:
            # Behaviour is unchanged (the global weights are left as they
            # were), but the misconfiguration is no longer silent.
            logging.warning(
                'Unknown fl_technique %r: client weights were not aggregated',
                cfg.fl_technique)

        # Test the performance of the averaged model
        avged_model = Classifier(cfg).to(device)
        avged_model.load_state_dict(w_global)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      avged_model,
                                                      dataloader_dev)
        time_spent = time.time() - time_now

        auclist = []
        for task in range(num_tasks):
            fpr, tpr, _ = metrics.roc_curve(true_list[task],
                                            predlist[task],
                                            pos_label=1)
            auclist.append(metrics.auc(fpr, tpr))
        auc_summary = np.array(auclist)

        loss_dev_str = ' '.join('{:.5f}'.format(x)
                                for x in summary_dev['loss'])
        acc_dev_str = ' '.join('{:.3f}'.format(x) for x in summary_dev['acc'])
        auc_dev_str = ' '.join('{:.3f}'.format(x) for x in auc_summary)

        logging.info(
            '{}, Averaged Model -> Dev, Step : {}, Loss : {}, Acc : {}, Auc : {},'
            'Mean auc: {:.3f} '
            'Run Time : {:.2f} sec'.format(time.strftime("%Y-%m-%d %H:%M:%S"),
                                           summary_train['step'], loss_dev_str,
                                           acc_dev_str, auc_dev_str,
                                           auc_summary.mean(), time_spent))

        for t in range(num_tasks):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      auc_summary[t], summary_train['step'])

        # Track the best value of every metric; only the metric selected by
        # cfg.best_target triggers a "best" checkpoint.
        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = auc_summary[cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': avged_model.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))
            # Rotate through best1..best{save_top_k}.
            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {},'
                         'Auc :{},Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))

        # Rolling checkpoint of the latest aggregated model, every round.
        torch.save(
            {
                'epoch': cr,
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': avged_model.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))