示例#1
0
def build_model(cfg, paramsfile):
    """Create a ``Classifier`` on the CPU and load weights from ``paramsfile``.

    The checkpoint may either be the state dict itself or a dict that holds
    it under the ``'state_dict'`` key. If the checkpoint also contains both
    ``'step'`` and ``'auc_dev_best'``, they are printed.

    Args:
        cfg: configuration object passed straight to ``Classifier``.
        paramsfile: path of the checkpoint file to load.

    Returns:
        The model with weights loaded, switched to evaluation mode.
    """
    net = Classifier(cfg).to('cpu')
    checkpoint = torch.load(paramsfile, map_location='cpu')

    # Accept both a wrapped checkpoint and a bare state dict.
    if 'state_dict' in checkpoint:
        weights = checkpoint['state_dict']
    else:
        weights = checkpoint
    net.load_state_dict(weights)

    has_metadata = 'step' in checkpoint and 'auc_dev_best' in checkpoint
    if has_metadata:
        print(f"Using model '{paramsfile}' at step: {checkpoint['step']} "
              f"with AUC: {checkpoint['auc_dev_best']}")
    return net.eval()
示例#2
0
def run(args):
    """Train a ``Classifier`` and evaluate it on the dev set after each epoch.

    Reads the JSON config at ``args.cfg_path``, prepares ``args.save_path``
    (logging, a copy of the config when not resuming, copies of the
    train/dev CSVs, summary-writer output), optionally loads
    ``args.pre_train`` weights or resumes from ``<save_path>/train.ckpt``,
    then loops over epochs. After every epoch the per-task dev loss,
    accuracy and ROC AUC are logged and written to the summary writer,
    ``train.ckpt`` is overwritten, and a rotating ``best{idx}.ckpt`` is
    saved whenever the metric selected by ``cfg.best_target`` ('acc',
    'auc' or 'loss') is at least as good as the best value seen so far.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``cfg_path``, ``save_path``, ``verbose``, ``logtofile``,
            ``resume``, ``device_ids`` (comma-separated GPU indices),
            ``pre_train`` and ``num_workers``.

    Raises:
        Exception: if fewer CUDA devices are available than were requested.
    """
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose is True:
            print(json.dumps(cfg, indent=4))

    # makedirs (not mkdir) so that a nested save path is created as well.
    os.makedirs(args.save_path, exist_ok=True)
    if args.logtofile is True:
        # NOTE(review): filemode "w" truncates the log even when resuming.
        logging.basicConfig(filename=args.save_path + '/log.txt',
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    # The first requested GPU hosts the DataParallel master copy.
    device = torch.device('cuda:{}'.format(device_ids[0]))

    model = Classifier(cfg)
    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))
    model = DataParallel(model, device_ids=device_ids).to(device).train()
    if args.pre_train is not None and os.path.exists(args.pre_train):
        ckpt = torch.load(args.pre_train, map_location=device)
        model.module.load_state_dict(ckpt)
    optimizer = get_optimizer(model.parameters(), cfg)

    # Keep a copy of the data splits next to the checkpoints.
    copyfile(cfg.train_csv, os.path.join(args.save_path, 'train.csv'))
    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'valid.csv'))

    dataloader_train = DataLoader(ImageDataset(cfg.train_csv,
                                               cfg,
                                               mode='train'),
                                  batch_size=cfg.train_batch_size,
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  shuffle=True)
    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header
    num_tasks = len(cfg.num_classes)

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    epoch_start = 0
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    if args.resume:
        ckpt_path = os.path.join(args.save_path, 'train.ckpt')
        ckpt = torch.load(ckpt_path, map_location=device)
        model.module.load_state_dict(ckpt['state_dict'])
        summary_train = {'epoch': ckpt['epoch'], 'step': ckpt['step']}
        best_dict['acc_dev_best'] = ckpt['acc_dev_best']
        best_dict['loss_dev_best'] = ckpt['loss_dev_best']
        best_dict['auc_dev_best'] = ckpt['auc_dev_best']
        epoch_start = ckpt['epoch']

    # try/finally so the summary writer is closed even if an epoch raises.
    try:
        for epoch in range(epoch_start, cfg.epoch):
            lr = lr_schedule(cfg.lr, cfg.lr_factor, summary_train['epoch'],
                             cfg.lr_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            summary_train, best_dict = train_epoch(
                summary_train, summary_dev, cfg, args, model,
                dataloader_train, dataloader_dev, optimizer, summary_writer,
                best_dict, dev_header)

            time_now = time.time()
            summary_dev, predlist, true_list = test_epoch(
                summary_dev, cfg, args, model, dataloader_dev)
            time_spent = time.time() - time_now

            # One ROC AUC per task, computed from the dev predictions.
            auclist = []
            for i in range(num_tasks):
                fpr, tpr, _ = metrics.roc_curve(true_list[i],
                                                predlist[i],
                                                pos_label=1)
                auclist.append(metrics.auc(fpr, tpr))
            summary_dev['auc'] = np.array(auclist)

            loss_dev_str = ' '.join(
                '{:.5f}'.format(x) for x in summary_dev['loss'])
            acc_dev_str = ' '.join(
                '{:.3f}'.format(x) for x in summary_dev['acc'])
            auc_dev_str = ' '.join(
                '{:.3f}'.format(x) for x in summary_dev['auc'])

            logging.info('{}, Dev, Step : {}, Loss : {}, Acc : {}, Auc : {},'
                         'Mean auc: {:.3f} '
                         'Run Time : {:.2f} sec'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str,
                             acc_dev_str, auc_dev_str,
                             summary_dev['auc'].mean(), time_spent))

            for t in range(num_tasks):
                summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                          summary_dev['loss'][t],
                                          summary_train['step'])
                summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                          summary_dev['acc'][t],
                                          summary_train['step'])
                summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                          summary_dev['auc'][t],
                                          summary_train['step'])

            # All three running bests are tracked; only the metric named by
            # cfg.best_target triggers saving a "best" checkpoint.
            save_best = False

            mean_acc = summary_dev['acc'][cfg.save_index].mean()
            if mean_acc >= best_dict['acc_dev_best']:
                best_dict['acc_dev_best'] = mean_acc
                if cfg.best_target == 'acc':
                    save_best = True

            mean_auc = summary_dev['auc'][cfg.save_index].mean()
            if mean_auc >= best_dict['auc_dev_best']:
                best_dict['auc_dev_best'] = mean_auc
                if cfg.best_target == 'auc':
                    save_best = True

            mean_loss = summary_dev['loss'][cfg.save_index].mean()
            if mean_loss <= best_dict['loss_dev_best']:
                best_dict['loss_dev_best'] = mean_loss
                if cfg.best_target == 'loss':
                    save_best = True

            # The same payload goes to both the "best" and the rolling
            # checkpoint, so it is built once.
            checkpoint = {
                'epoch': summary_train['epoch'],
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': model.module.state_dict()
            }

            if save_best:
                torch.save(
                    checkpoint,
                    os.path.join(args.save_path,
                                 'best{}.ckpt'.format(best_dict['best_idx'])))
                # Rotate through slots 1..save_top_k.
                best_dict['best_idx'] += 1
                if best_dict['best_idx'] > cfg.save_top_k:
                    best_dict['best_idx'] = 1
                logging.info('{}, Best, Step : {}, Loss : {}, Acc : {},'
                             'Auc :{},Best Auc : {:.3f}'.format(
                                 time.strftime("%Y-%m-%d %H:%M:%S"),
                                 summary_train['step'], loss_dev_str,
                                 acc_dev_str, auc_dev_str,
                                 best_dict['auc_dev_best']))
            torch.save(checkpoint, os.path.join(args.save_path, 'train.ckpt'))
    finally:
        summary_writer.close()
示例#3
0
def run(args, val_h5_file):
    """Train a ``Classifier`` on pre-chunked arrays and evaluate every epoch.

    Reads the JSON config at ``args.cfg_path``, prepares ``args.save_path``,
    loads the validation arrays from ``val_h5_file`` and the training
    samples/labels from ``args.train_chunks``, then loops over epochs.
    After every epoch the per-task dev loss, accuracy and ROC AUC are
    logged and written to the summary writer, ``train.ckpt`` is
    overwritten, and a rotating ``best{idx}.ckpt`` is saved whenever the
    metric selected by ``cfg.best_target`` ('acc', 'auc' or 'loss') is at
    least as good as the best value seen so far.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``cfg_path``, ``save_path``, ``verbose``, ``logtofile``,
            ``resume``, ``device_ids``, ``pre_train``, ``num_workers``,
            ``train_chunks`` (directory holding ``train_labels.h5`` and
            ``chexpert_dset_chunk_<i>.npy``), ``chunk_count``, ``q``, ``k``.
        val_h5_file: mapping indexed with the keys ``'val'``,
            ``'val_u_ones'``, ``'val_u_zeros'`` and ``'val_u_random'``
            (presumably an open ``h5py.File`` - confirm with the caller).

    Raises:
        Exception: if fewer CUDA devices are available than were requested.
        ValueError: if ``args.chunk_count`` yields no training chunk.
    """
    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose is True:
            print(json.dumps(cfg, indent=4))

    # makedirs (not mkdir) so that a nested save path is created as well.
    os.makedirs(args.save_path, exist_ok=True)
    if args.logtofile is True:
        # NOTE(review): filemode "w" truncates the log even when resuming.
        logging.basicConfig(filename=args.save_path + '/log.txt',
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    # The first requested GPU hosts the DataParallel master copy.
    device = torch.device('cuda:{}'.format(device_ids[0]))

    model = Classifier(cfg)
    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))
    model = DataParallel(model, device_ids=device_ids).to(device).train()
    if args.pre_train is not None and os.path.exists(args.pre_train):
        ckpt = torch.load(args.pre_train, map_location=device)
        model.module.load_state_dict(ckpt)
    optimizer = get_optimizer(model.parameters(), cfg)

    # Validation images and the three label variants, read into memory.
    np_val_h5_file = np.array(val_h5_file['val'], dtype=np.uint8)
    np_v_u_ones = np.array(val_h5_file['val_u_ones'], dtype=np.int8)
    np_v_u_zeros = np.array(val_h5_file['val_u_zeros'], dtype=np.int8)
    np_v_u_random = np.array(val_h5_file['val_u_random'], dtype=np.int8)

    train_labels = {}
    with h5py.File(f'{args.train_chunks}/train_labels.h5', 'r') as fp:
        for key in ('train_u_ones', 'train_u_zeros', 'train_u_random'):
            train_labels[key] = np.array(fp[key], dtype=np.int8)

    # Load every chunk first and concatenate once; re-concatenating the
    # growing array inside the loop copies the data again on each iteration.
    chunks = []
    for i in range(args.chunk_count):
        with open(f'{args.train_chunks}/chexpert_dset_chunk_{i+1}.npy',
                  'rb') as f:
            chunks.append(np.load(f))
    if not chunks:
        raise ValueError(
            'no training chunks loaded (chunk_count={})'.format(
                args.chunk_count))
    np_train_samples = (chunks[0] if len(chunks) == 1
                        else np.concatenate(chunks))
    del chunks  # drop the per-chunk references once merged

    dataloader_train = DataLoader(ImageDataset(
        [np_train_samples, train_labels], cfg, mode='train'),
                                  batch_size=cfg.train_batch_size,
                                  num_workers=args.num_workers,
                                  drop_last=True,
                                  shuffle=True)

    dataloader_dev = DataLoader(ImageDataset(
        [np_val_h5_file, np_v_u_zeros, np_v_u_ones, np_v_u_random],
        cfg,
        mode='val'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    # Task names used to tag the per-task scalars below; hard-coded here
    # instead of being read from the dataset.
    dev_header = [
        'No_Finding', 'Enlarged_Cardiomediastinum', 'Cardiomegaly',
        'Lung_Opacity', 'Lung_Lesion', 'Edema', 'Consolidation', 'Pneumonia',
        'Atelectasis', 'Pneumothorax', 'Pleural_Effusion', 'Pleural_Other',
        'Fracture', 'Support_Devices'
    ]
    print(f'dataloaders are set. train count: {np_train_samples.shape[0]}')
    logging.info("[LOGGING TEST]: dataloaders are set...")
    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    epoch_start = 0
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    if args.resume:
        ckpt_path = os.path.join(args.save_path, 'train.ckpt')
        ckpt = torch.load(ckpt_path, map_location=device)
        model.module.load_state_dict(ckpt['state_dict'])
        summary_train = {'epoch': ckpt['epoch'], 'step': ckpt['step']}
        best_dict['acc_dev_best'] = ckpt['acc_dev_best']
        best_dict['loss_dev_best'] = ckpt['loss_dev_best']
        best_dict['auc_dev_best'] = ckpt['auc_dev_best']
        epoch_start = ckpt['epoch']

    # The same q / k value is repeated once per task.
    num_tasks = len(cfg.num_classes)
    q_list = torch.FloatTensor([args.q] * num_tasks)
    k_list = torch.FloatTensor([args.k] * num_tasks)
    loss_sq_hinge = MultiClassSquaredHingeLoss()
    print('Everything is set starting to train...')
    before = datetime.datetime.now()

    # try/finally so the summary writer is closed even if an epoch raises.
    try:
        for epoch in range(epoch_start, cfg.epoch):
            lr = lr_schedule(cfg.lr, cfg.lr_factor, summary_train['epoch'],
                             cfg.lr_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            summary_train, best_dict = train_epoch(
                summary_train, summary_dev, cfg, args, model,
                dataloader_train, dataloader_dev, optimizer, summary_writer,
                best_dict, dev_header, q_list, k_list, loss_sq_hinge)

            time_now = time.time()
            summary_dev, predlist, true_list = test_epoch(
                summary_dev, cfg, args, model, dataloader_dev, q_list,
                k_list, loss_sq_hinge)
            time_spent = time.time() - time_now

            # One ROC AUC per task, computed from the dev predictions.
            auclist = []
            for i in range(num_tasks):
                fpr, tpr, _ = metrics.roc_curve(true_list[i],
                                                predlist[i],
                                                pos_label=1)
                auclist.append(metrics.auc(fpr, tpr))
            summary_dev['auc'] = np.array(auclist)

            loss_dev_str = ' '.join(
                '{:.5f}'.format(x) for x in summary_dev['loss'])
            acc_dev_str = ' '.join(
                '{:.3f}'.format(x) for x in summary_dev['acc'])
            auc_dev_str = ' '.join(
                '{:.3f}'.format(x) for x in summary_dev['auc'])

            logging.info('{}, Dev, Step : {}, Loss : {}, Acc : {}, Auc : {},'
                         'Mean auc: {:.3f} '
                         'Run Time : {:.2f} sec'.format(
                             time.strftime("%Y-%m-%d %H:%M:%S"),
                             summary_train['step'], loss_dev_str,
                             acc_dev_str, auc_dev_str,
                             summary_dev['auc'].mean(), time_spent))

            for t in range(num_tasks):
                summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                          summary_dev['loss'][t],
                                          summary_train['step'])
                summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                          summary_dev['acc'][t],
                                          summary_train['step'])
                summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                          summary_dev['auc'][t],
                                          summary_train['step'])

            # All three running bests are tracked; only the metric named by
            # cfg.best_target triggers saving a "best" checkpoint.
            save_best = False

            mean_acc = summary_dev['acc'][cfg.save_index].mean()
            if mean_acc >= best_dict['acc_dev_best']:
                best_dict['acc_dev_best'] = mean_acc
                if cfg.best_target == 'acc':
                    save_best = True

            mean_auc = summary_dev['auc'][cfg.save_index].mean()
            if mean_auc >= best_dict['auc_dev_best']:
                best_dict['auc_dev_best'] = mean_auc
                if cfg.best_target == 'auc':
                    save_best = True

            mean_loss = summary_dev['loss'][cfg.save_index].mean()
            if mean_loss <= best_dict['loss_dev_best']:
                best_dict['loss_dev_best'] = mean_loss
                if cfg.best_target == 'loss':
                    save_best = True

            # The same payload goes to both the "best" and the rolling
            # checkpoint, so it is built once.
            checkpoint = {
                'epoch': summary_train['epoch'],
                'step': summary_train['step'],
                'acc_dev_best': best_dict['acc_dev_best'],
                'auc_dev_best': best_dict['auc_dev_best'],
                'loss_dev_best': best_dict['loss_dev_best'],
                'state_dict': model.module.state_dict()
            }

            if save_best:
                torch.save(
                    checkpoint,
                    os.path.join(args.save_path,
                                 'best{}.ckpt'.format(best_dict['best_idx'])))
                # Rotate through slots 1..save_top_k.
                best_dict['best_idx'] += 1
                if best_dict['best_idx'] > cfg.save_top_k:
                    best_dict['best_idx'] = 1
                logging.info('{}, Best, Step : {}, Loss : {}, Acc : {},'
                             'Auc :{},Best Auc : {:.3f}'.format(
                                 time.strftime("%Y-%m-%d %H:%M:%S"),
                                 summary_train['step'], loss_dev_str,
                                 acc_dev_str, auc_dev_str,
                                 best_dict['auc_dev_best']))
            torch.save(checkpoint, os.path.join(args.save_path, 'train.ckpt'))

            print_remaining_time(before,
                                 epoch + 1,
                                 cfg.epoch,
                                 additional='[training]')
    finally:
        summary_writer.close()
示例#4
0
def run_fl(args):
    """Run federated training of a ``Classifier`` over several clients.

    Each entry of ``cfg.train_csv`` becomes one client (named 'A', 'B', ...
    so at most 26 clients are created). In every communication round each
    client starts from the current global weights, trains for
    ``cfg.local_epoch`` epochs, and its weights are collected. The global
    weights are then aggregated according to ``cfg.fl_technique``
    ('FedAvg', 'WFedAvg' or 'FedProx'), the averaged model is evaluated on
    the dev set, metrics are logged and written to the summary writer,
    ``train.ckpt`` is overwritten, and a rotating ``best{idx}.ckpt`` is
    saved when the metric selected by ``cfg.best_target`` is at least as
    good as the best value seen so far.

    Args:
        args: parsed command-line namespace. Attributes read here:
            ``cfg_path``, ``save_path``, ``verbose``, ``logtofile``,
            ``resume``, ``device_ids``, ``pre_train`` and ``num_workers``.

    Raises:
        ValueError: if ``cfg.fl_technique`` is not a supported technique.
        Exception: if fewer CUDA devices are available than requested, or
            if sizing/copying the source folder fails.
    """
    import shlex  # local import: only needed to quote the shell paths below

    with open(args.cfg_path) as f:
        cfg = edict(json.load(f))
        if args.verbose is True:
            print(json.dumps(cfg, indent=4))

    # Fail fast: an unknown technique would otherwise skip aggregation
    # silently and keep the initial global weights for every round.
    if cfg.fl_technique not in ('FedAvg', 'WFedAvg', 'FedProx'):
        raise ValueError('unsupported fl_technique: {!r}'.format(
            cfg.fl_technique))

    # makedirs (not mkdir) so that a nested save path is created as well.
    os.makedirs(args.save_path, exist_ok=True)
    if args.logtofile is True:
        logging.basicConfig(filename=args.save_path + '/log.txt',
                            filemode="w",
                            level=logging.INFO)
    else:
        logging.basicConfig(level=logging.INFO)

    if not args.resume:
        with open(os.path.join(args.save_path, 'cfg.json'), 'w') as f:
            json.dump(cfg, f, indent=1)

    device_ids = list(map(int, args.device_ids.split(',')))
    num_devices = torch.cuda.device_count()
    if num_devices < len(device_ids):
        raise Exception('#available gpu : {} < --device_ids : {}'.format(
            num_devices, len(device_ids)))
    device = torch.device('cuda:{}'.format(device_ids[0]))

    # initialise global model
    model = Classifier(cfg).to(device).train()

    if args.verbose is True:
        from torchsummary import summary
        if cfg.fix_ratio:
            h, w = cfg.long_side, cfg.long_side
        else:
            h, w = cfg.height, cfg.width
        summary(model.to(device), (3, h, w))

    if args.pre_train is not None and os.path.exists(args.pre_train):
        ckpt = torch.load(args.pre_train, map_location=device)
        model.load_state_dict(ckpt)

    # Snapshot the source tree next to the results. Paths are shell-quoted
    # so directories containing spaces or metacharacters are handled.
    src_folder = os.path.dirname(os.path.abspath(__file__)) + '/../'
    dst_folder = os.path.join(args.save_path, 'classification')
    rc, size = subprocess.getstatusoutput('du --max-depth=0 %s | cut -f1' %
                                          shlex.quote(src_folder))

    if rc != 0:
        raise Exception('Copy folder error : {}'.format(rc))
    else:
        print('Successfully determined size of directory')

    rc, err_msg = subprocess.getstatusoutput(
        'cp -R %s %s' % (shlex.quote(src_folder), shlex.quote(dst_folder)))
    if rc != 0:
        raise Exception('copy folder error : {}'.format(err_msg))
    else:
        print('Successfully copied folder')

    # One client per training CSV, named with consecutive uppercase letters
    # (zip stops at whichever of the two sequences is shorter).
    train_files = cfg.train_csv
    clients = {}
    for client, train_file in zip(string.ascii_uppercase, train_files):
        copyfile(train_file,
                 os.path.join(args.save_path, f'train_{client}.csv'))
        clients[client] = {
            'dataloader_train': DataLoader(
                ImageDataset(train_file, cfg, mode='train'),
                batch_size=cfg.train_batch_size,
                num_workers=args.num_workers,
                drop_last=True,
                shuffle=True),
            'bytes_uploaded': 0.0,
            'epoch': 0,
        }
    copyfile(cfg.dev_csv, os.path.join(args.save_path, 'dev.csv'))

    dataloader_dev = DataLoader(ImageDataset(cfg.dev_csv, cfg, mode='dev'),
                                batch_size=cfg.dev_batch_size,
                                num_workers=args.num_workers,
                                drop_last=False,
                                shuffle=False)
    dev_header = dataloader_dev.dataset._label_header
    num_tasks = len(cfg.num_classes)

    w_global = model.state_dict()

    summary_train = {'epoch': 0, 'step': 0}
    summary_dev = {'loss': float('inf'), 'acc': 0.0}
    summary_writer = SummaryWriter(args.save_path)
    comm_rounds = cfg.epoch
    best_dict = {
        "acc_dev_best": 0.0,
        "auc_dev_best": 0.0,
        "loss_dev_best": float('inf'),
        "fused_dev_best": 0.0,
        "best_idx": 1
    }

    # try/finally so the summary writer is always closed.
    try:
        # Communication rounds loop
        for cr in range(comm_rounds):
            logging.info(
                '{}, Start communication round {} of FL - {} ...'.format(
                    time.strftime("%Y-%m-%d %H:%M:%S"), cr + 1,
                    cfg.fl_technique))

            w_locals = []

            for client in clients:

                logging.info(
                    '{}, Start local training process for client {}, communication round: {} ...'
                    .format(time.strftime("%Y-%m-%d %H:%M:%S"), client,
                            cr + 1))

                # Every client starts from a fresh model holding the
                # current global weights.
                model = Classifier(cfg).to(device).train()
                model.load_state_dict(w_global)

                if cfg.fl_technique == "FedProx":
                    global_weight_collector = get_global_weights(
                        model, device)
                else:
                    global_weight_collector = None

                optimizer = get_optimizer(model.parameters(), cfg)

                # local training loops
                for epoch in range(cfg.local_epoch):
                    lr = lr_schedule(cfg.lr, cfg.lr_factor, epoch,
                                     cfg.lr_epochs)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr

                    summary_train, best_dict = train_epoch_fl(
                        summary_train, summary_dev, cfg, args, model,
                        clients[client]['dataloader_train'], dataloader_dev,
                        optimizer, summary_writer, best_dict, dev_header,
                        epoch, global_weight_collector)

                    summary_train['step'] += 1

                # sys.getsizeof() is shallow and would only measure the
                # dict container, so sum the tensors' payload sizes.
                local_weights = model.state_dict()
                bytes_to_upload = sum(
                    tensor.numel() * tensor.element_size()
                    for tensor in local_weights.values())
                clients[client]['bytes_uploaded'] += bytes_to_upload
                logging.info(
                    '{}, Completed local rounds for client {} in communication round {}. '
                    'Uploading {} bytes to server, {} bytes in total sent from client'
                    .format(time.strftime("%Y-%m-%d %H:%M:%S"), client,
                            cr + 1, bytes_to_upload,
                            clients[client]['bytes_uploaded']))

                w_locals.append(local_weights)

            if cfg.fl_technique == "FedAvg":
                w_global = fed_avg(w_locals)
            else:
                # 'WFedAvg' and 'FedProx' both aggregate with the weighted
                # average (the technique was validated up front).
                w_global = weighted_fed_avg(w_locals, cfg.train_proportions)

            # Test the performance of the averaged model
            avged_model = Classifier(cfg).to(device)
            avged_model.load_state_dict(w_global)

            time_now = time.time()
            summary_dev, predlist, true_list = test_epoch(
                summary_dev, cfg, args, avged_model, dataloader_dev)
            time_spent = time.time() - time_now

            # One ROC AUC per task, computed from the dev predictions.
            auclist = []
            for i in range(num_tasks):
                fpr, tpr, _ = metrics.roc_curve(true_list[i],
                                                predlist[i],
                                                pos_label=1)
                auclist.append(metrics.auc(fpr, tpr))
            auc_summary = np.array(auclist)

            loss_dev_str = ' '.join(
                '{:.5f}'.format(x) for x in summary_dev['loss'])
            acc_dev_str = ' '.join(
                '{:.3f}'.format(x) for x in summary_dev['acc'])
            auc_dev_str = ' '.join('{:.3f}'.format(x) for x in auc_summary)

            logging.info(
                '{}, Averaged Model -> Dev, Step : {}, Loss : {}, Acc : {}, Auc : {},'
                'Mean auc: {:.3f} '
                'Run Time : {:.2f} sec'.format(
                    time.strftime("%Y-%m-%d %H:%M:%S"),
                    summary_train['step'], loss_dev_str, acc_dev_str,
                    auc_dev_str, auc_summary.mean(), time_spent))

            for t in range(num_tasks):
                summary_writer.add_scalar('dev/loss_{}'.format(dev_header[t]),
                                          summary_dev['loss'][t],
                                          summary_train['step'])
                summary_writer.add_scalar('dev/acc_{}'.format(dev_header[t]),
                                          summary_dev['acc'][t],
                                          summary_train['step'])
                summary_writer.add_scalar('dev/auc_{}'.format(dev_header[t]),
                                          auc_summary[t],
                                          summary_train['step'])

            # All three running bests are tracked; only the metric named by
            # cfg.best_target triggers saving a "best" checkpoint.
            save_best = False

            mean_acc = summary_dev['acc'][cfg.save_index].mean()
            if mean_acc >= best_dict['acc_dev_best']:
                best_dict['acc_dev_best'] = mean_acc
                if cfg.best_target == 'acc':
                    save_best = True

            mean_auc = auc_summary[cfg.save_index].mean()
            if mean_auc >= best_dict['auc_dev_best']:
                best_dict['auc_dev_best'] = mean_auc
                if cfg.best_target == 'auc':
                    save_best = True

            mean_loss = summary_dev['loss'][cfg.save_index].mean()
            if mean_loss <= best_dict['loss_dev_best']:
                best_dict['loss_dev_best'] = mean_loss
                if cfg.best_target == 'loss':
                    save_best = True

            if save_best:
                torch.save(
                    {
                        'epoch': summary_train['epoch'],
                        'step': summary_train['step'],
                        'acc_dev_best': best_dict['acc_dev_best'],
                        'auc_dev_best': best_dict['auc_dev_best'],
                        'loss_dev_best': best_dict['loss_dev_best'],
                        'state_dict': avged_model.state_dict()
                    },
                    os.path.join(args.save_path,
                                 'best{}.ckpt'.format(best_dict['best_idx'])))

                # Rotate through slots 1..save_top_k.
                best_dict['best_idx'] += 1
                if best_dict['best_idx'] > cfg.save_top_k:
                    best_dict['best_idx'] = 1
                logging.info('{}, Best, Step : {}, Loss : {}, Acc : {},'
                             'Auc :{},Best Auc : {:.3f}'.format(
                                 time.strftime("%Y-%m-%d %H:%M:%S"),
                                 summary_train['step'], loss_dev_str,
                                 acc_dev_str, auc_dev_str,
                                 best_dict['auc_dev_best']))
            # NOTE(review): the rolling checkpoint stores the round index as
            # 'epoch', while the "best" checkpoint stores
            # summary_train['epoch'] - confirm which one is intended.
            torch.save(
                {
                    'epoch': cr,
                    'step': summary_train['step'],
                    'acc_dev_best': best_dict['acc_dev_best'],
                    'auc_dev_best': best_dict['auc_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': avged_model.state_dict()
                }, os.path.join(args.save_path, 'train.ckpt'))
    finally:
        summary_writer.close()
示例#5
0
    with open(join(args.save_path, 'cfg.json'), 'w') as f:
        json.dump(cfg, f, indent=1)

# Parse the comma-separated GPU indices and make sure enough GPUs exist.
device_ids = list(map(int, args.device_ids.split(',')))
num_devices = torch.cuda.device_count()
# Explicit raise instead of `assert`: assertions are stripped under
# `python -O`, which would silently disable this check.
if num_devices < len(device_ids):
    raise AssertionError(f"""
#available gpu : {num_devices} < --device_ids : {len(device_ids)}""")

# The first requested GPU hosts the DataParallel master copy.
device = torch.device(f"cuda:{device_ids[0]}")

model = Classifier(cfg)
if args.verbose:
    from torchsummary import summary
    # Square input when the aspect ratio is fixed, else explicit H x W.
    h, w = (cfg.long_side, cfg.long_side) if cfg.fix_ratio \
           else (cfg.height, cfg.width)
    summary(model.to(device), (3, h, w))

model = DataParallel(model, device_ids=device_ids).to(device)
# Optional warm start; a missing file is skipped silently.
if args.pre_train is not None and exists(args.pre_train):
    ckpt = torch.load(args.pre_train, map_location=device)
    model.module.load_state_dict(ckpt)

optimizer = get_optimizer(model.parameters(), cfg)

trainset = ImageDataset(cfg.train_csv, cfg, mode='train')
testset = ImageDataset(cfg.dev_csv, cfg, mode='val')

trainloader = DataLoader(trainset, batch_size=cfg.train_batch_size,
    num_workers=args.num_workers, drop_last=True, shuffle=True)
testloader = DataLoader(testset, batch_size=cfg.dev_batch_size,
示例#6
0
# Convert a trained PyTorch Classifier checkpoint to Caffe:
# load config + weights, pickle the whole model, run the PyTorch->Caffe
# parser on it, then load the generated Caffe model.
with open(args.model_path+'cfg.json') as f:
    cfg = edict(json.load(f))

model_file = "model/best.pth"
device = torch.device('cpu')  # PyTorch v0.4.0
net = Classifier(cfg)
# map_location keeps the load working on CPU-only hosts even when the
# checkpoint tensors were saved from a GPU.
ckpt = torch.load("model/best.ckpt", map_location=device)
# NOTE(review): strict=False silently ignores missing/unexpected keys.
net.load_state_dict(ckpt['state_dict'], strict=False)
torch.save(net, model_file)

net.eval()

dummy_input = torch.ones([1, 3, 1024, 1024])

# Forward pass on the CPU with a dummy 1024x1024 input.
net.to(device)
output = net(dummy_input)

device = torch.device("cuda")  # PyTorch v0.4.0
summary(net.to(device), (3, 1024, 1024))

pytorch_parser = PytorchParser(model_file, [3, 1024, 1024])
pytorch_parser.run(model_file)

# The Caffe definition and weights are read from these paths, derived
# from the pickled model's file name.
Model_FILE = model_file + '.prototxt'

PRETRAINED = model_file + '.caffemodel'

net = caffe.Classifier(Model_FILE, PRETRAINED)