Example #1
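Trains a multi-label image classifier wrapped in DataParallel: reads a JSON
config into an edict, builds train/dev DataLoaders from CSV files, alternates
train and test epochs, computes per-class ROC AUC with scikit-learn, logs to
TensorBoard, and writes both a rolling train.ckpt and rotating best{k}.ckpt
checkpoints.
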
# Imports assumed by this snippet; Classifier, ImageDataset, get_optimizer,
# lr_schedule, train_epoch and test_epoch are project-local helpers that the
# listing does not show (SummaryWriter may equally come from
# torch.utils.tensorboard).
import json
import logging
import os
import time
from shutil import copyfile

import numpy as np
import torch
from easydict import EasyDict as edict
from sklearn import metrics
from tensorboardX import SummaryWriter
from torch.nn import DataParallel
from torch.utils.data import DataLoader


def run(args):
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose:
            print(json.dumps(cfg, indent=4))

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    if args.logtofile:
        logging.basicConfig(filename=os.path.join(args.save_path, 'log.txt'),
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    model = Classifier(cfg)
    if args.verbose:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))
    model = DataParallel(model, device_ids=device_ids).to(device).train()
    if args.pre_train is not None:
        if os.path.exists(args.pre_train):
            ckpt = torch.load(args.pre_train, map_location=device)
            model.module.load_state_dict(ckpt)
    optimizer = get_optimizer(model.parameters(), cfg)

    src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    dst_folder = os.path.join(args.save_path, 'classification')
    # rc, size = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1'
    #                                       % src_folder)
    # if rc != 0:
    #     print(size)
    #     raise Exception('Copy folder error : {}'.format(rc))
    # rc, err_msg = subprocess.getstatusoutput('cp -R %s %s' % (src_folder,
    #                                                           dst_folder))
    # if rc != 0:
    #     raise Exception('copy folder error : {}'.format(err_msg))

    copyfile(cfg.train_csv, os.path.join(args.save_path, 'train.csv'))
    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'valid.csv'))

    dataloader_train = DataLoader(ImageDataset(cfg.train_csv,
                                               cfg,
                                               mode='train'),
                                  batch_size=cfg.train_batch_size,
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  shuffle=True)
    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    epoch_start = 0
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    if args.resume:
        ckpt_path = os.path.join(args.save_path, 'train.ckpt')
        ckpt = torch.load(ckpt_path, map_location=device)
        model.module.load_state_dict(ckpt['state_dict'])
        summary_train = {'epoch': ckpt['epoch'], 'step': ckpt['step']}
        best_dict['acc_dev_best'] = ckpt['acc_dev_best']
        best_dict['loss_dev_best'] = ckpt['loss_dev_best']
        best_dict['auc_dev_best'] = ckpt['auc_dev_best']
        epoch_start = ckpt['epoch']

    for epoch in range(epoch_start, cfg.epoch):
        lr = lr_schedule(cfg.lr, cfg.lr_factor, summary_train['epoch'],
                         cfg.lr_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        summary_train, best_dict = train_epoch(summary_train, summary_dev, cfg,
                                               args, model, dataloader_train,
                                               dataloader_dev, optimizer,
                                               summary_writer, best_dict,
                                               dev_header)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      model, dataloader_dev)
        time_spent = time.time() - time_now

        auclist = []
        for i in range(len(cfg.num_classes)):
            y_pred = predlist[i]
            y_true = true_list[i]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_pred,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            auclist.append(auc)
        summary_dev['auc'] = np.array(auclist)

        loss_dev_str = ' '.join(
            map(lambda x: '{:.5f}'.format(x), summary_dev['loss']))
        acc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['acc']))
        auc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['auc']))

        logging.info('{}, Dev, Step : {}, Loss : {}, Acc : {}, Auc : {}, '
                     'Mean auc : {:.3f}, '
                     'Run Time : {:.2f} sec'.format(
                         time.strftime("%Y-%m-%d %H:%M:%S"),
                         summary_train['step'], loss_dev_str, acc_dev_str,
                         auc_dev_str, summary_dev['auc'].mean(), time_spent))

        for t in range(len(cfg.num_classes)):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      summary_dev['auc'][t],
                                      summary_train['step'])

        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = summary_dev['auc'][cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': model.module.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))
            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {}, '
                         'Auc : {}, Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))
        torch.save(
            {
                'epoch': summary_train['epoch'],
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': model.module.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))
    summary_writer.close()
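
Most of these examples consume an args namespace produced elsewhere. A minimal
sketch of the argparse wiring they appear to assume (flag names are taken from
the attribute accesses above; defaults and help strings are hypothetical):

import argparse

parser = argparse.ArgumentParser(description='Train a multi-label classifier')
parser.add_argument('cfg_path', help='path to the JSON config file')
parser.add_argument('save_path', help='directory for logs and checkpoints')
parser.add_argument('--device_ids', default='0',
                    help='comma-separated GPU ids, e.g. "0,1"')
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--pre_train', default=None,
                    help='optional checkpoint to warm-start from')
parser.add_argument('--resume', action='store_true')
parser.add_argument('--logtofile', action='store_true')
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()
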
Example #2
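A variant of Example #1 that feeds the model from preprocessed arrays instead
of CSV-listed images: training samples are concatenated from
chexpert_dset_chunk_*.npy files, validation data comes from an HDF5 file with
u-ones/u-zeros/u-random label variants, and per-class q/k tensors plus a
squared-hinge loss are threaded through the train/test loops.
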
# Same imports as Example #1, plus h5py and datetime; the extended
# train_epoch/test_epoch signatures, MultiClassSquaredHingeLoss and
# print_remaining_time are project-local and not shown.
def run(args, val_h5_file):
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose:
            print(json.dumps(cfg, indent=4))

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    if args.logtofile:
        logging.basicConfig(filename=os.path.join(args.save_path, 'log.txt'),
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    model = Classifier(cfg)
    if args.verbose:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))
    model = DataParallel(model, device_ids=device_ids).to(device).train()
    if args.pre_train is not None:
        if os.path.exists(args.pre_train):
            ckpt = torch.load(args.pre_train, map_location=device)
            model.module.load_state_dict(ckpt)
    optimizer = get_optimizer(model.parameters(), cfg)

    #src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    #dst_folder = os.path.join(args.save_path, 'classification')
    #rc, size = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1' % src_folder)
    #if rc != 0: raise Exception('Copy folder error : {}'.format(rc))
    #rc, err_msg = subprocess.getstatusoutput('cp -R %s %s' % (src_folder, dst_folder))
    #if rc != 0: raise Exception('copy folder error : {}'.format(err_msg))
    #copyfile(cfg.train_csv, os.path.join(args.save_path, 'train.csv'))
    #copyfile(cfg.dev_csv, os.path.join(args.save_path, 'dev.csv'))
    # np_train_h5_file = np.array(train_h5_file['train'][:10000], dtype=np.uint8)
    # np_t_u_ones = np.array(train_h5_file['train_u_ones'][:10000], dtype=np.int8)
    # np_t_u_zeros = np.array(train_h5_file['train_u_zeros'][:10000], dtype=np.int8)
    # np_t_u_random = np.array(train_h5_file['train_u_random'][:10000], dtype=np.int8)

    np_val_h5_file = np.array(val_h5_file['val'], dtype=np.uint8)
    np_v_u_ones = np.array(val_h5_file['val_u_ones'], dtype=np.int8)
    np_v_u_zeros = np.array(val_h5_file['val_u_zeros'], dtype=np.int8)
    np_v_u_random = np.array(val_h5_file['val_u_random'], dtype=np.int8)

    train_labels = {}
    with h5py.File(f'{args.train_chunks}/train_labels.h5', 'r') as fp:
        train_labels['train_u_ones'] = np.array(fp['train_u_ones'],
                                                dtype=np.int8)
        train_labels['train_u_zeros'] = np.array(fp['train_u_zeros'],
                                                 dtype=np.int8)
        train_labels['train_u_random'] = np.array(fp['train_u_random'],
                                                  dtype=np.int8)
    np_train_samples = None
    for i in range(args.chunk_count):
        with open(f'{args.train_chunks}/chexpert_dset_chunk_{i+1}.npy',
                  'rb') as f:
            if np_train_samples is None:
                np_train_samples = np.load(f)
            else:
                np_train_samples = np.concatenate(
                    (np_train_samples, np.load(f)))

    dataloader_train = DataLoader(ImageDataset(
        [np_train_samples, train_labels], cfg, mode='train'),
                                  batch_size=cfg.train_batch_size,
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  shuffle=True)

    dataloader_dev = DataLoader(ImageDataset(
        [np_val_h5_file, np_v_u_zeros, np_v_u_ones, np_v_u_random],
        cfg,
        mode='val'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    #dev_header = dataloader_dev.dataset._label_header
    dev_header = [
        'No_Finding', 'Enlarged_Cardiomediastinum', 'Cardiomegaly',
        'Lung_Opacity', 'Lung_Lesion', 'Edema', 'Consolidation', 'Pneumonia',
        'Atelectasis', 'Pneumothorax', 'Pleural_Effusion', 'Pleural_Other',
        'Fracture', 'Support_Devices'
    ]
    print(f'dataloaders are set. train count: {np_train_samples.shape[0]}')
    logging.info("[LOGGING TEST]: dataloaders are set...")
    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    epoch_start = 0
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    if args.resume:
        ckpt_path = os.path.join(args.save_path, 'train.ckpt')
        ckpt = torch.load(ckpt_path, map_location=device)
        model.module.load_state_dict(ckpt['state_dict'])
        summary_train = {'epoch': ckpt['epoch'], 'step': ckpt['step']}
        best_dict['acc_dev_best'] = ckpt['acc_dev_best']
        best_dict['loss_dev_best'] = ckpt['loss_dev_best']
        best_dict['auc_dev_best'] = ckpt['auc_dev_best']
        epoch_start = ckpt['epoch']

    # One q/k value per class, as float tensors.
    q_list = torch.FloatTensor([args.q] * len(cfg.num_classes))
    k_list = torch.FloatTensor([args.k] * len(cfg.num_classes))
    loss_sq_hinge = MultiClassSquaredHingeLoss()
    print('Everything is set starting to train...')
    before = datetime.datetime.now()
    for epoch in range(epoch_start, cfg.epoch):
        lr = lr_schedule(cfg.lr, cfg.lr_factor, summary_train['epoch'],
                         cfg.lr_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        summary_train, best_dict = train_epoch(summary_train, summary_dev, cfg,
                                               args, model, dataloader_train,
                                               dataloader_dev, optimizer,
                                               summary_writer, best_dict,
                                               dev_header, q_list, k_list,
                                               loss_sq_hinge)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      model, dataloader_dev,
                                                      q_list, k_list,
                                                      loss_sq_hinge)
        time_spent = time.time() - time_now

        auclist = []
        for i in range(len(cfg.num_classes)):
            y_pred = predlist[i]
            y_true = true_list[i]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_pred,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            auclist.append(auc)
        summary_dev['auc'] = np.array(auclist)

        loss_dev_str = ' '.join(
            map(lambda x: '{:.5f}'.format(x), summary_dev['loss']))
        acc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['acc']))
        auc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['auc']))

        logging.info('{}, Dev, Step : {}, Loss : {}, Acc : {}, Auc : {}, '
                     'Mean auc : {:.3f}, '
                     'Run Time : {:.2f} sec'.format(
                         time.strftime("%Y-%m-%d %H:%M:%S"),
                         summary_train['step'], loss_dev_str, acc_dev_str,
                         auc_dev_str, summary_dev['auc'].mean(), time_spent))

        for t in range(len(cfg.num_classes)):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      summary_dev['auc'][t],
                                      summary_train['step'])

        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = summary_dev['auc'][cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': model.module.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))
            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {}, '
                         'Auc : {}, Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))
        torch.save(
            {
                'epoch': summary_train['epoch'],
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': model.module.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))

        print_remaining_time(before,
                             epoch + 1,
                             cfg.epoch,
                             additional='[training]')
    summary_writer.close()
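
lr_schedule is one of the project helpers the listing omits. Given how it is
called above, lr_schedule(cfg.lr, cfg.lr_factor, epoch, cfg.lr_epochs), a
plausible step-decay reconstruction (hypothetical, not the project's actual
code) is:

def lr_schedule(lr, lr_factor, epoch, lr_epochs):
    # Multiply the base lr by lr_factor once for every milestone in
    # lr_epochs that the current epoch has already reached.
    for milestone in lr_epochs:
        if epoch >= milestone:
            lr *= lr_factor
    return lr
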
Example #3
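A three-step training driver for a DenseNet branch, a ResNet branch, and a
fusion head on ChestXray-NIHCC or CheXpert-v1.0-small: each step freezes or
unfreezes the relevant branches, trains with BCELoss, validates every
eval_step epochs, steps the schedulers, and checkpoints whichever branches
were trained.
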
# Assumes a module-level argparse namespace `args`, torch/torchvision imports
# (nn, cudnn, transforms, DataLoader), and the project-local DatasetGenerator,
# Classifier, Fusion, Logger, init_optim, init_scheduler, save_checkpoint,
# train and valid helpers.
def main():
    if args.dataset == 'ChestXray-NIHCC':
        if args.no_fiding:
            classes = [
                'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
                'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening',
                'Hernia', 'No Fiding'
            ]
        else:
            classes = [
                'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation',
                'Edema', 'Emphysema', 'Fibrosis', 'Pleural_Thickening',
                'Hernia'
            ]
    elif args.dataset == 'CheXpert-v1.0-small':
        classes = [
            'No Finding', 'Enlarged Cardiomediastinum', 'Cardiomegaly',
            'Lung Opacity', 'Lung Lesion', 'Edema', 'Consolidation',
            'Pneumonia', 'Atelectasis', 'Pneumothorax', 'Pleural Effusion',
            'Pleural Other', 'Fracture', 'Support Devices'
        ]
    else:
        print('--dataset incorrect')
        return

    torch.manual_seed(args.seed)
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    pin_memory = use_gpu

    print("Initializing dataset: {}".format(args.dataset))

    data_transforms = {
        'train':
        transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(556),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
        'valid':
        transforms.Compose([
            transforms.Resize(556),
            transforms.CenterCrop(512),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ]),
    }

    datasetTrain = DatasetGenerator(path_base=args.base_dir,
                                    dataset_file='train',
                                    transform=data_transforms['train'],
                                    dataset_=args.dataset,
                                    no_fiding=args.no_fiding)

    datasetVal = DatasetGenerator(path_base=args.base_dir,
                                  dataset_file='valid',
                                  transform=data_transforms['valid'],
                                  dataset_=args.dataset,
                                  no_fiding=args.no_fiding)

    train_loader = DataLoader(dataset=datasetTrain,
                              batch_size=args.train_batch,
                              shuffle=args.train_shuffle,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    valid_loader = DataLoader(dataset=datasetVal,
                              batch_size=args.valid_batch,
                              shuffle=args.valid_shuffle,
                              num_workers=args.workers,
                              pin_memory=pin_memory)

    with open(args.infos_densenet) as f:
        cfg = edict(json.load(f))

    print('Initializing densenet branch')
    model_dense = Classifier(cfg)
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model_dense.parameters()) / 1000000.0))

    with open(args.infos_resnet) as f:
        cfg = edict(json.load(f))  # note: this overwrites the densenet cfg;
                                   # train()/valid() below receive this one

    print('Initializing resnet branch')
    model_res = Classifier(cfg)
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model_res.parameters()) / 1000000.0))

    print('Initializing fusion branch')
    model_fusion = Fusion(input_size=7424, output_size=len(classes))
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model_fusion.parameters()) / 1000000.0))

    print("Initializing optimizers")
    optimizer_dense = init_optim(args.optim, model_dense.parameters(),
                                 args.learning_rate, args.weight_decay,
                                 args.momentum)
    optimizer_res = init_optim(args.optim, model_res.parameters(),
                               args.learning_rate, args.weight_decay,
                               args.momentum)
    optimizer_fusion = init_optim(args.optim, model_fusion.parameters(),
                                  args.learning_rate, args.weight_decay,
                                  args.momentum)

    criterion = nn.BCELoss()

    print("Initializing scheduler: {}".format(args.scheduler))
    if args.stepsize > 0:
        scheduler_dense = init_scheduler(args.scheduler, optimizer_dense,
                                         args.stepsize, args.gamma)
        scheduler_res = init_scheduler(args.scheduler, optimizer_res,
                                       args.stepsize, args.gamma)
        scheduler_fusion = init_scheduler(args.scheduler, optimizer_fusion,
                                          args.stepsize, args.gamma)

    start_epoch = args.start_epoch
    best_loss = np.inf

    if args.resume_densenet:
        checkpoint_dense = torch.load(args.resume_densenet)
        model_dense.load_state_dict(checkpoint_dense['state_dict'])
        epoch_dense = checkpoint_dense['epoch']
        print("Resuming densenet from epoch {}".format(epoch_dense + 1))

    if args.resume_resnet:
        checkpoint_res = torch.load(args.resume_resnet)
        model_res.load_state_dict(checkpoint_res['state_dict'])
        epoch_res = checkpoint_res['epoch']
        print("Resuming resnet from epoch {}".format(epoch_res + 1))

    if args.resume_fusion:
        checkpoint_fusion = torch.load(args.resume_fusion)
        model_fusion.load_state_dict(checkpoint_fusion['state_dict'])
        epoch_fusion = checkpoint_fusion['epoch']
        print("Resuming fusion from epoch {}".format(epoch_fusion + 1))

    if use_gpu:
        model_dense = nn.DataParallel(model_dense).cuda()
        model_res = nn.DataParallel(model_res).cuda()
        model_fusion = nn.DataParallel(model_fusion).cuda()

    if args.evaluate:
        print("Evaluate only")
        if args.step in (1, 2, 3):
            valid('step{}'.format(args.step), model_dense, model_res,
                  model_fusion, valid_loader, criterion, args.print_freq,
                  classes, cfg, data_transforms['valid'])
        else:
            print('Unrecognized args.step value')
        return

    if args.step == 1:
        #################################### DENSENET BRANCH INIT ##########################################
        start_time = time.time()
        train_time = 0
        best_epoch = 0
        print("==> Start training of densenet branch")

        for p in model_dense.parameters():
            p.requires_grad = True

        for p in model_res.parameters():
            p.requires_grad = False

        for p in model_fusion.parameters():
            p.requires_grad = True

        for epoch in range(start_epoch, args.max_epoch):
            start_train_time = time.time()
            train('step1', model_dense, model_res, model_fusion, train_loader,
                  optimizer_dense, optimizer_res, optimizer_fusion, criterion,
                  args.print_freq, epoch, args.max_epoch, cfg,
                  data_transforms['train'])
            train_time += round(time.time() - start_train_time)
            if (args.eval_step > 0 and (epoch + 1) % args.eval_step == 0) \
                    or (epoch + 1) == args.max_epoch:
                print("==> Validation")
                loss_val = valid('step1', model_dense, model_res, model_fusion,
                                 valid_loader, criterion, args.print_freq,
                                 classes, cfg, data_transforms['valid'])

                if args.stepsize > 0:
                    if args.scheduler == 'ReduceLROnPlateau':
                        scheduler_dense.step(loss_val)
                        scheduler_fusion.step(loss_val)
                    else:
                        scheduler_dense.step()
                        scheduler_fusion.step()

                is_best = loss_val < best_loss
                if is_best:
                    best_loss = loss_val
                    best_epoch = epoch + 1

                if use_gpu:
                    state_dict_dense = model_dense.module.state_dict()
                    state_dict_fusion = model_fusion.module.state_dict()
                else:
                    state_dict_dense = model_dense.state_dict()
                    state_dict_fusion = model_fusion.state_dict()

                save_checkpoint(
                    {
                        'state_dict': state_dict_dense,
                        'loss': best_loss,
                        'epoch': epoch,
                    }, is_best, args.save_dir,
                    'checkpoint_ep' + str(epoch + 1) + '.pth.tar', 'dense')
                save_checkpoint(
                    {
                        'state_dict': state_dict_fusion,
                        'loss': best_loss,
                        'epoch': epoch,
                    }, is_best, args.save_dir,
                    'checkpoint_ep' + str(epoch + 1) + '.pth.tar', 'fusion')

        print("==> Best Validation Loss {:.4%}, achieved at epoch {}".format(
            best_loss, best_epoch))

        elapsed = round(time.time() - start_time)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        train_time = str(datetime.timedelta(seconds=train_time))
        print(
            "Dense branch finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}."
            .format(elapsed, train_time))
        #################################### DENSENET BRANCH END ##########################################

    elif args.step == 2:
        #################################### RESNET BRANCH INIT ##########################################
        start_time = time.time()
        train_time = 0
        best_epoch = 0
        print("==> Start training of local branch")

        for p in model_dense.parameters():
            p.requires_grad = False

        for p in model_res.parameters():
            p.requires_grad = True

        for p in model_fusion.parameters():
            p.requires_grad = True

        for epoch in range(start_epoch, args.max_epoch):
            start_train_time = time.time()
            train('step2', model_dense, model_res, model_fusion, train_loader,
                  optimizer_dense, optimizer_res, optimizer_fusion, criterion,
                  args.print_freq, epoch, args.max_epoch, cfg,
                  data_transforms['train'])
            train_time += round(time.time() - start_train_time)
            if (args.eval_step > 0 and (epoch + 1) % args.eval_step == 0) \
                    or (epoch + 1) == args.max_epoch:
                print("==> Validation")
                loss_val = valid('step2', model_dense, model_res, model_fusion,
                                 valid_loader, criterion, args.print_freq,
                                 classes, cfg, data_transforms['valid'])

                if args.stepsize > 0:
                    if args.scheduler == 'ReduceLROnPlateau':
                        scheduler_res.step(loss_val)
                        scheduler_fusion.step(loss_val)
                    else:
                        scheduler_res.step()
                        scheduler_fusion.step()

                is_best = loss_val < best_loss
                if is_best:
                    best_loss = loss_val
                    best_epoch = epoch + 1

                if use_gpu:
                    state_dict_res = model_res.module.state_dict()
                    state_dict_fusion = model_fusion.module.state_dict()
                else:
                    state_dict_res = model_res.state_dict()
                    state_dict_fusion = model_fusion.state_dict()

                save_checkpoint(
                    {
                        'state_dict': state_dict_res,
                        'loss': best_loss,
                        'epoch': epoch,
                    }, is_best, args.save_dir,
                    'checkpoint_ep' + str(epoch + 1) + '.pth.tar', 'res')
                save_checkpoint(
                    {
                        'state_dict': state_dict_fusion,
                        'loss': best_loss,
                        'epoch': epoch,
                    }, is_best, args.save_dir,
                    'checkpoint_ep' + str(epoch + 1) + '.pth.tar', 'fusion')

        print("==> Best Validation Loss {:.4%}, achieved at epoch {}".format(
            best_loss, best_epoch))

        elapsed = round(time.time() - start_time)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        train_time = str(datetime.timedelta(seconds=train_time))
        print(
            "Resnet branch finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}."
            .format(elapsed, train_time))
        #################################### RESNET BRANCH END ##########################################

    elif args.step == 3:
        #################################### FUSION BRANCH INIT ##########################################
        start_time = time.time()
        train_time = 0
        best_epoch = 0
        print("==> Start training of fusion branch")

        for p in model_dense.parameters():
            p.requires_grad = True

        for p in model_res.parameters():
            p.requires_grad = True

        for p in model_fusion.parameters():
            p.requires_grad = True

        for epoch in range(start_epoch, args.max_epoch):
            start_train_time = time.time()
            train('step3', model_dense, model_res, model_fusion, train_loader,
                  optimizer_dense, optimizer_res, optimizer_fusion, criterion,
                  args.print_freq, epoch, args.max_epoch, cfg,
                  data_transforms['train'])
            train_time += round(time.time() - start_train_time)
            if (args.eval_step > 0 and (epoch + 1) % args.eval_step == 0) \
                    or (epoch + 1) == args.max_epoch:
                print("==> Validation")
                loss_val = valid('step3', model_dense, model_res, model_fusion,
                                 valid_loader, criterion, args.print_freq,
                                 classes, cfg, data_transforms['valid'])

                if args.stepsize > 0:
                    if args.scheduler == 'ReduceLROnPlateau':
                        scheduler_dense.step(loss_val)
                        scheduler_res.step(loss_val)
                        scheduler_fusion.step(loss_val)
                    else:
                        scheduler_dense.step()
                        scheduler_res.step()
                        scheduler_fusion.step()

                is_best = loss_val < best_loss
                if is_best:
                    best_loss = loss_val
                    best_epoch = epoch + 1

                if use_gpu:
                    state_dict_dense = model_dense.module.state_dict()
                    state_dict_res = model_res.module.state_dict()
                    state_dict_fusion = model_fusion.module.state_dict()
                else:
                    state_dict_dense = model_dense.state_dict()
                    state_dict_res = model_res.state_dict()
                    state_dict_fusion = model_fusion.state_dict()

                save_checkpoint(
                    {
                        'state_dict': state_dict_dense,
                        'loss': best_loss,
                        'epoch': epoch,
                    }, is_best, args.save_dir,
                    'checkpoint_ep' + str(epoch + 1) + '.pth.tar', 'dense')
                save_checkpoint(
                    {
                        'state_dict': state_dict_res,
                        'loss': best_loss,
                        'epoch': epoch,
                    }, is_best, args.save_dir,
                    'checkpoint_ep' + str(epoch + 1) + '.pth.tar', 'res')
                save_checkpoint(
                    {
                        'state_dict': state_dict_fusion,
                        'loss': best_loss,
                        'epoch': epoch,
                    }, is_best, args.save_dir,
                    'checkpoint_ep' + str(epoch + 1) + '.pth.tar', 'fusion')

        print("==> Best Validation Loss {:.4%}, achieved at epoch {}".format(
            best_loss, best_epoch))

        elapsed = round(time.time() - start_time)
        elapsed = str(datetime.timedelta(seconds=elapsed))
        train_time = str(datetime.timedelta(seconds=train_time))
        print(
            "Fusion branch finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}."
            .format(elapsed, train_time))
        #################################### FUSION BRANCH END ##########################################

    else:
        print('Unrecognized args.step value')
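
init_optim is another helper the listing does not show. From its call sites,
init_optim(args.optim, parameters, lr, weight_decay, momentum), a minimal
factory sketch (the accepted optimizer names are an assumption) could be:

import torch

def init_optim(optim, params, lr, weight_decay, momentum):
    # Map a CLI string to a torch optimizer; extend as needed.
    if optim == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if optim == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=momentum,
                               weight_decay=weight_decay)
    raise KeyError('Unsupported optimizer: {}'.format(optim))
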
Example #4
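A fragment of a CVAE/GAN-style setup: it builds a latent-space classifier, one
Adam optimizer per sub-model, allocates the next free numbered output
directory under ./ex/, and begins a dictionary for tracking loss curves.
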
    # Fragment from a larger training script: cvae, dis, opts and device are
    # defined earlier and are not shown in the listing.
    classifier = Classifier(opts.latent_size).to(device)
    classer = CLASSIFIERS().to(device)

    print(cvae)
    print(dis)
    print(classifier)

    optimizer_cvae = torch.optim.Adam(cvae.parameters(),
                                      lr=opts.lr,
                                      betas=(opts.b1, opts.b2),
                                      weight_decay=opts.weight_decay)
    optimizer_dis = torch.optim.Adam(dis.parameters(),
                                     lr=opts.lr,
                                     betas=(opts.b1, opts.b2),
                                     weight_decay=opts.weight_decay)
    optimizer_classifier = torch.optim.Adam(classifier.parameters(),
                                            lr=opts.lr,
                                            betas=(opts.b1, opts.b2),
                                            weight_decay=opts.weight_decay)

    i = 1
    while os.path.isdir('./ex/' + str(i)):
        i += 1
    os.mkdir('./ex/' + str(i))
    output_path = './ex/' + str(i)

    losses = {
        'total': [],
        'kl': [],
        'bce': [],
        'dis': [],
        # ... (the losses dict and the rest of this snippet are truncated
        # in the original listing)
Example #5
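A federated-learning version of Example #1: one client per training CSV, then
a loop of communication rounds in which every client fine-tunes a copy of the
global weights locally, followed by FedAvg / weighted FedAvg / FedProx
aggregation and evaluation of the averaged model on the shared dev set.
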
# Same imports as Example #1, plus string, subprocess and sys; fed_avg,
# weighted_fed_avg, get_global_weights and train_epoch_fl are project-local
# helpers not shown in the listing.
def run_fl(args):
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose:
            print(json.dumps(cfg, indent=4))

    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    if args.logtofile:
        logging.basicConfig(filename=os.path.join(args.save_path, 'log.txt'),
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    # initialise global model
    model = Classifier(cfg).to(device).train()

    if args.verbose:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))

    if args.pre_train is not None:
        if os.path.exists(args.pre_train):
            ckpt = torch.load(args.pre_train, map_location=device)
            model.load_state_dict(ckpt)

    src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    dst_folder = os.path.join(args.save_path, 'classification')
    rc, size = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1' %
                                          src_folder)

    if rc != 0:
        raise Exception('Copy folder error : {}'.format(rc))
    else:
        print('Successfully determined size of directory')

    rc, err_msg = subprocess.getstatusoutput('cp -R %s %s' %
                                             (src_folder, dst_folder))
    if rc != 0:
        raise Exception('copy folder error : {}'.format(err_msg))
    else:
        print('Successfully copied folder')

    # copy train files
    train_files = cfg.train_csv
    clients = {}
    for i, c in enumerate(string.ascii_uppercase):
        if i < len(train_files):
            clients[c] = {}
        else:
            break

    # initialise clients
    for i, client in enumerate(clients):
        copyfile(train_files[i],
                 os.path.join(args.save_path, f'train_{client}.csv'))
        clients[client]['dataloader_train'] =\
            DataLoader(
                ImageDataset(train_files[i], cfg, mode='train'),
                batch_size=cfg.train_batch_size,
                num_workers=args.num_workers, drop_last=True,
                shuffle=True
            )
        clients[client]['bytes_uploaded'] = 0.0
        clients[client]['epoch'] = 0
    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'dev.csv'))

    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header

    w_global = model.state_dict()

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    comm_rounds = cfg.epoch
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    # Communication rounds loop
    for cr in range(comm_rounds):
        logging.info('{}, Start communication round {} of FL - {} ...'.format(
            time.strftime("%Y-%m-%d %H:%M:%S"), cr + 1, cfg.fl_technique))

        w_locals = []

        for client in clients:

            logging.info(
                '{}, Start local training process for client {}, communication round: {} ...'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1))

            # Load the current global weights as the starting point
            model = Classifier(cfg).to(device).train()

            model.load_state_dict(w_global)

            if cfg.fl_technique == "FedProx":
                global_weight_collector = get_global_weights(model, device)
            else:
                global_weight_collector = None

            optimizer = get_optimizer(model.parameters(), cfg)

            # local training loops
            for epoch in range(cfg.local_epoch):
                lr = lr_schedule(cfg.lr, cfg.lr_factor, epoch, cfg.lr_epochs)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                summary_train, best_dict = train_epoch_fl(
                    summary_train, summary_dev, cfg, args, model,
                    clients[client]['dataloader_train'], dataloader_dev,
                    optimizer, summary_writer, best_dict, dev_header, epoch,
                    global_weight_collector)

                summary_train['step'] += 1

            # sys.getsizeof measures only the dict container, not the tensors
            # it references, so this understates the real payload size.
            bytes_to_upload = sys.getsizeof(model.state_dict())
            clients[client]['bytes_uploaded'] += bytes_to_upload
            logging.info(
                '{}, Completed local rounds for client {} in communication round {}. '
                'Uploading {} bytes to server, {} bytes in total sent from client'
                .format(time.strftime("%Y-%m-%d %H:%M:%S"), client, cr + 1,
                        bytes_to_upload, clients[client]['bytes_uploaded']))

            w_locals.append(model.state_dict())

        if cfg.fl_technique == "FedAvg":
            w_global = fed_avg(w_locals)
        elif cfg.fl_technique == 'WFedAvg':
            w_global = weighted_fed_avg(w_locals, cfg.train_proportions)
        elif cfg.fl_technique == 'FedProx':
            # Use weighted FedAvg when using FedProx
            w_global = weighted_fed_avg(w_locals, cfg.train_proportions)

        # Test the performance of the averaged model
        avged_model = Classifier(cfg).to(device)
        avged_model.load_state_dict(w_global)

        time_now = time.time()
        summary_dev, predlist, true_list = test_epoch(summary_dev, cfg, args,
                                                      avged_model,
                                                      dataloader_dev)
        time_spent = time.time() - time_now

        auclist = []
        for i in range(len(cfg.num_classes)):
            y_pred = predlist[i]
            y_true = true_list[i]
            fpr, tpr, thresholds = metrics.roc_curve(y_true,
                                                     y_pred,
                                                     pos_label=1)
            auc = metrics.auc(fpr, tpr)
            auclist.append(auc)
        auc_summary = np.array(auclist)

        loss_dev_str = ' '.join(
            map(lambda x: '{:.5f}'.format(x), summary_dev['loss']))
        acc_dev_str = ' '.join(
            map(lambda x: '{:.3f}'.format(x), summary_dev['acc']))
        auc_dev_str = ' '.join(map(lambda x: '{:.3f}'.format(x), auc_summary))

        logging.info(
            '{}, Averaged Model -> Dev, Step : {}, Loss : {}, Acc : {}, Auc : {}, '
            'Mean auc : {:.3f}, '
            'Run Time : {:.2f} sec'.format(time.strftime("%Y-%m-%d %H:%M:%S"),
                                           summary_train['step'], loss_dev_str,
                                           acc_dev_str, auc_dev_str,
                                           auc_summary.mean(), time_spent))

        for t in range(len(cfg.num_classes)):
            summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                      summary_dev['loss'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                      summary_dev['acc'][t],
                                      summary_train['step'])
            summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                      auc_summary[t], summary_train['step'])

        save_best = False

        mean_acc = summary_dev['acc'][cfg.save_index].mean()
        if mean_acc >= best_dict['acc_dev_best']:
            best_dict['acc_dev_best'] = mean_acc
            if cfg.best_target == 'acc':
                save_best = True

        mean_auc = auc_summary[cfg.save_index].mean()
        if mean_auc >= best_dict['auc_dev_best']:
            best_dict['auc_dev_best'] = mean_auc
            if cfg.best_target == 'auc':
                save_best = True

        mean_loss = summary_dev['loss'][cfg.save_index].mean()
        if mean_loss <= best_dict['loss_dev_best']:
            best_dict['loss_dev_best'] = mean_loss
            if cfg.best_target == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': avged_model.state_dict()
                },
                os.path.join(args.save_path,
                             'best{}.ckpt'.format(best_dict['best_idx'])))

            best_dict['best_idx'] += 1
            if best_dict['best_idx'] > cfg.save_top_k:
                best_dict['best_idx'] = 1
            logging.info('{}, Best, Step : {}, Loss : {}, Acc : {}, '
                         'Auc : {}, Best Auc : {:.3f}'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str, acc_dev_str,
                             auc_dev_str, best_dict['auc_dev_best']))
        torch.save(
            {
                'epoch': cr,
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': avged_model.state_dict()
            }, os.path.join(args.save_path, 'train.ckpt'))
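
fed_avg and weighted_fed_avg are project helpers as well. A minimal sketch of
state-dict averaging under the usual FedAvg assumptions (every client trains
the same architecture; train_proportions sums to 1) might look like:

def fed_avg(w_locals):
    # Element-wise mean of the clients' state_dicts (vanilla FedAvg).
    w_avg = {k: v.clone().float() for k, v in w_locals[0].items()}
    for w in w_locals[1:]:
        for k in w_avg:
            w_avg[k] += w[k].float()
    return {k: v / len(w_locals) for k, v in w_avg.items()}

def weighted_fed_avg(w_locals, proportions):
    # Same idea, but each client's update counts in proportion to its
    # share of the training data.
    w_avg = {k: v.clone().float() * proportions[0]
             for k, v in w_locals[0].items()}
    for w, p in zip(w_locals[1:], proportions[1:]):
        for k in w_avg:
            w_avg[k] += w[k].float() * p
    return w_avg
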
Example #6
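A condensed, module-level rewrite of Example #1's setup, ending where the
logging dictionaries are initialized.
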
device = torch.device(f"cuda:{device_ids[0]}")

model = Classifier(cfg)
if args.verbose:
    from torchsummary import summary
    h, w = (cfg.long_side, cfg.long_side) if cfg.fix_ratio \
           else (cfg.height, cfg.width)
    summary(model.to(device), (3, h, w))

model = DataParallel(model, device_ids=device_ids).to(device)
if args.pre_train is not None:
    if exists(args.pre_train):
        ckpt = torch.load(args.pre_train, map_location=device)
        model.module.load_state_dict(ckpt)

optimizer = get_optimizer(model.parameters(), cfg)

trainset = ImageDataset(cfg.train_csv, cfg, mode='train')
testset = ImageDataset(cfg.dev_csv, cfg, mode='val')

trainloader = DataLoader(trainset, batch_size=cfg.train_batch_size,
    num_workers=args.num_workers, drop_last=True, shuffle=True)
testloader = DataLoader(testset, batch_size=cfg.dev_batch_size,
    num_workers=args.num_workers, drop_last=False, shuffle=False)

dev_header = testloader.dataset._label_header

# Initialize parameters to log training output

summary_train = {'epoch': 0, 'step': 0}
summary_dev = {'loss': float('inf'), 'acc': 0.0}