Ejemplo n.º 1
0
def train_load(args):
    """
    Build the training and validation data loaders.

    The glob patterns in ``args.data_train`` are expanded, the resulting paths
    are sorted and then shuffled in place with ``np.random.shuffle`` (no seed
    is set here). Both datasets are built from the same file list: the
    training dataset is given the load range ``(0, args.train_val_split)`` and
    the validation dataset ``(args.train_val_split, 1)``.

    When ``args.demo`` is set, only the first 20 files are kept and
    ``args.data_fraction`` / ``args.fetch_step`` are overwritten on ``args``
    with 0.1 / 0.002.

    :param args: parsed command-line options
    :return: train_loader, val_loader, data_config, train_input_names, train_label_names
    """
    files = []
    for pattern in args.data_train:
        files.extend(glob.glob(pattern))
    files.sort()
    np.random.shuffle(files)

    if args.demo:
        files = files[:20]
        _logger.info(files)
        args.data_fraction = 0.1
        args.fetch_step = 0.002

    # cap the worker count by the (fraction-scaled) number of input files
    n_workers = min(args.num_workers, int(len(files) * args.file_fraction))

    def _make_dataset(load_range):
        # training and validation datasets differ only in the range they load
        return SimpleIterDataset(
            files,
            args.data_config,
            for_training=True,
            load_range_and_fraction=(load_range, args.data_fraction),
            file_fraction=args.file_fraction,
            fetch_by_files=args.fetch_by_files,
            fetch_step=args.fetch_step)

    def _make_loader(dataset):
        return DataLoader(dataset,
                          num_workers=n_workers,
                          batch_size=args.batch_size,
                          drop_last=True,
                          pin_memory=True)

    train_data = _make_dataset((0, args.train_val_split))
    val_data = _make_dataset((args.train_val_split, 1))
    train_loader = _make_loader(train_data)
    val_loader = _make_loader(val_data)

    data_config = train_data.config
    return (train_loader, val_loader, data_config,
            train_data.config.input_names, train_data.config.label_names)
Ejemplo n.º 2
0
def test_load(args):
    """
    Loads the test data.

    Each entry of ``args.data_test`` is one of:

    * ``/path/pattern`` -- files are added to the unnamed group ``''``;
    * ``name:/path/pattern`` -- files are added to the group ``name``;
    * ``name%N:/path/pattern`` -- as above, and the (sorted) group is then
      split into chunks of at most ``N`` files named ``name_0``, ``name_1``, ...

    :param args: parsed command-line options
    :return: test_loaders, data_config; ``test_loaders`` maps each group name
        to a zero-argument callable that builds the ``DataLoader`` on demand
    :raises ValueError: if a split size is not a positive integer
    """
    # keyword-based --data-test: 'a:/path/to/a b:/path/to/b'
    # split --data-test: 'a%10:/path/to/a/*'
    file_dict = {}
    split_dict = {}
    for f in args.data_test:
        if ':' in f:
            # split on the first colon only, so the path part may itself
            # contain ':' (e.g. a URL-style path) without raising
            name, fp = f.split(':', 1)
            if '%' in name:
                name, split = name.split('%')
                split = int(split)
                if split < 1:
                    # 0 would divide by zero below; a negative value would
                    # silently drop every file of the group
                    raise ValueError(
                        'Invalid split size %d for test file group %s: must be >= 1'
                        % (split, name))
                split_dict[name] = split
        else:
            name, fp = '', f
        file_dict.setdefault(name, []).extend(glob.glob(fp))

    # sort files
    for name, files in file_dict.items():
        file_dict[name] = sorted(files)

    # apply splitting
    for name, split in split_dict.items():
        files = file_dict.pop(name)
        for i in range((len(files) + split - 1) // split):
            file_dict[f'{name}_{i}'] = files[i * split:(i + 1) * split]

    def get_test_loader(name):
        filelist = file_dict[name]
        _logger.info('Running on test file group %s with %d files:\n...%s',
                     name, len(filelist), '\n...'.join(filelist))
        num_workers = min(args.num_workers, len(filelist))
        test_data = SimpleIterDataset(
            filelist,
            args.data_config,
            for_training=False,
            load_range_and_fraction=((0, 1), args.data_fraction),
            fetch_by_files=True,
            fetch_step=1)
        test_loader = DataLoader(test_data,
                                 num_workers=num_workers,
                                 batch_size=args.batch_size,
                                 drop_last=False,
                                 pin_memory=True)
        return test_loader

    # loaders are created lazily: each value is a callable, not a DataLoader
    test_loaders = {
        name: functools.partial(get_test_loader, name)
        for name in file_dict
    }
    # the config is obtained from a dataset built with an empty file list
    data_config = SimpleIterDataset([], args.data_config,
                                    for_training=False).config
    return test_loaders, data_config
Ejemplo n.º 3
0
 def get_test_loader(name):
     """
     Build the test ``DataLoader`` for the file group ``name``.

     ``file_dict`` and ``args`` are read from the enclosing scope, which is
     not part of this excerpt.

     :param name: key into ``file_dict`` selecting the list of files to load
     :return: test_loader
     """
     group_files = file_dict[name]
     _logger.info('Running on test file group %s with %d files:\n...%s',
                  name, len(group_files), '\n...'.join(group_files))
     dataset = SimpleIterDataset(
         group_files,
         args.data_config,
         for_training=False,
         load_range_and_fraction=((0, 1), args.data_fraction),
         fetch_by_files=True,
         fetch_step=1)
     # no more workers than there are files in the group
     return DataLoader(dataset,
                       num_workers=min(args.num_workers, len(group_files)),
                       batch_size=args.batch_size,
                       drop_last=False,
                       pin_memory=True)
Ejemplo n.º 4
0
def test_load(args):
    """
    Build the test data loader.

    All glob patterns in ``args.data_test`` are expanded into a single sorted
    file list, which is loaded over the full ``(0, 1)`` range with
    ``args.data_fraction``.

    :param args: parsed command-line options
    :return: test_loader, data_config
    """
    test_files = []
    for pattern in args.data_test:
        test_files.extend(glob.glob(pattern))
    test_files.sort()

    test_data = SimpleIterDataset(
        test_files,
        args.data_config,
        for_training=False,
        load_range_and_fraction=((0, 1), args.data_fraction),
        fetch_by_files=True,
        fetch_step=1)
    # no more workers than there are input files
    test_loader = DataLoader(test_data,
                             num_workers=min(args.num_workers,
                                             len(test_files)),
                             batch_size=args.batch_size,
                             drop_last=False,
                             pin_memory=True)
    return test_loader, test_data.config
Ejemplo n.º 5
0
def main():
    """
    Command-line entry point.

    Parses the options and then, depending on them, does one of:

    * ``--export-onnx PATH``: load the weights at ``--model-prefix`` and
      export the model to ONNX, then return;
    * ``--lr-finder START,END,ITERS``: run the learning-rate range test and
      write ``lr_finder.png``, then return;
    * default: train for ``--num-epochs`` epochs, saving the model and
      optimizer state after every epoch and tracking the best validation
      accuracy;
    * ``--predict``: evaluate the model on the test files and optionally write
      the scores to ``--predict-output``.

    :raises NotImplementedError: if ``--use-amp`` is given
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-c',
                        '--data-config',
                        type=str,
                        default='data/ak15_points_pf_sv_v0.yaml',
                        help='data config YAML file')
    parser.add_argument('-i',
                        '--data-train',
                        nargs='*',
                        default=[],
                        help='training files')
    parser.add_argument('-t',
                        '--data-test',
                        nargs='*',
                        default=[],
                        help='testing files')
    parser.add_argument(
        '--data-fraction',
        type=float,
        default=1,
        help=
        'fraction of events to load from each file; for training, the events are randomly selected for each epoch'
    )
    parser.add_argument(
        '--data-dilation',
        type=int,
        default=1,
        help=
        'reduce number of file by a factor of `d` for training. NOT recommended in general - use `--data-fraction` instead.'
    )
    parser.add_argument(
        '--files-per-fetch',
        type=int,
        default=20,
        help=
        'number of files to load each time; shuffling is done within these events, so choose a number large enough to get events from all classes'
    )
    parser.add_argument('--train-val-split',
                        type=float,
                        default=0.8,
                        help='training/validation split fraction')
    parser.add_argument(
        '--demo',
        action='store_true',
        default=False,
        help=
        'quickly test the setup by running over only a small number of events')
    parser.add_argument(
        '--lr-finder',
        type=str,
        default=None,
        help=
        'run learning rate finder instead of the actual training; format: ``start_lr, end_lr, num_iters``'
    )
    parser.add_argument(
        '-n',
        '--network-config',
        type=str,
        default='networks/particle_net_pfcand_sv.py',
        help=
        'network architecture configuration file; the path must be relative to the current dir'
    )
    parser.add_argument(
        '--network-option',
        nargs=2,
        action='append',
        default=[],
        help=
        'options to pass to the model class constructor, e.g., `--network-option use_counts False`'
    )
    parser.add_argument(
        '-m',
        '--model-prefix',
        type=str,
        default='test_output/model_name',
        help=
        'path to save or load the model; for training, this will be used as a prefix; for testing, this should be the full path including extension'
    )
    parser.add_argument('--num-epochs',
                        type=int,
                        default=20,
                        help='number of epochs')
    parser.add_argument(
        '--optimizer',
        type=str,
        default='ranger',
        choices=['adam', 'ranger'],  # TODO: add more
        help='optimizer for the training')
    parser.add_argument(
        '--load-epoch',
        type=int,
        default=None,
        help=
        'used to resume interrupted training, load model and optimizer state saved in the `epoch-%d_state.pt` and `epoch-%d_optimizer.pt` files'
    )
    parser.add_argument('--start-lr',
                        type=float,
                        default=5e-3,
                        help='start learning rate')
    parser.add_argument(
        '--lr-steps',
        type=str,
        default='10,20',
        help=
        'steps to reduce the lr; currently only used when setting `--optimizer` to adam'
    )
    parser.add_argument('--batch-size',
                        type=int,
                        default=128,
                        help='batch size')
    parser.add_argument(
        '--use-amp',
        action='store_true',
        default=False,
        help='use mixed precision training (fp16); NOT WORKING YET')
    parser.add_argument(
        '--gpus',
        type=str,
        default='0',
        help=
        "device for the training/testing; to use CPU, set to empty string (''); to use multiple gpu, set it as a comma separated list, e.g., `1,2,3,4`"
    )
    parser.add_argument(
        '--num-workers',
        type=int,
        default=2,
        help=
        'number of threads to load the dataset; memory consuption and disk access load increases (~linearly) with this numbers'
    )
    parser.add_argument('--predict',
                        action='store_true',
                        default=False,
                        help='run prediction instead of training')
    parser.add_argument(
        '--predict-output',
        type=str,
        help=
        'path to save the prediction output, support `.root` and `.awkd` format'
    )
    parser.add_argument(
        '--export-onnx',
        type=str,
        default=None,
        help=
        'export the PyTorch model to ONNX model and save it at the given path (path must ends w/ .onnx); '
        'needs to set `--data-config`, `--network-config`, and `--model-prefix` (requires the full model path)'
    )

    args = parser.parse_args()
    _logger.info(args)

    # TODO: mixed precision training is not implemented; fail early instead
    # of silently training in full precision
    if args.use_amp:
        raise NotImplementedError

    if args.data_dilation > 1:
        _logger.warning(
            'Use of `data-dilation` is not recomended in general -- consider using `data-fraction` instead.'
        )

    # training/testing mode
    training_mode = not args.predict

    # device: the first listed GPU is the primary device; an empty string
    # selects the CPU
    if args.gpus:
        gpus = [int(i) for i in args.gpus.split(',')]
        dev = torch.device(gpus[0])
    else:
        gpus = None
        dev = torch.device('cpu')

    # load data
    if training_mode:
        filelist = sorted(sum([glob.glob(f) for f in args.data_train], []))
        # note: the shuffle is not seeded
        np.random.shuffle(filelist)
        if args.demo:
            filelist = filelist[:20]
            _logger.info(filelist)
            args.data_fraction = 0.1
            args.files_per_fetch = 5
        # training and validation read the same files, but different ranges
        train_data = SimpleIterDataset(filelist,
                                       args.data_config,
                                       for_training=True,
                                       partial_load=((0, args.train_val_split),
                                                     args.data_fraction),
                                       dilation=args.data_dilation,
                                       files_per_fetch=args.files_per_fetch)
        val_data = SimpleIterDataset(filelist,
                                     args.data_config,
                                     for_training=True,
                                     partial_load=((args.train_val_split, 1),
                                                   args.data_fraction),
                                     dilation=args.data_dilation,
                                     files_per_fetch=args.files_per_fetch)
        train_loader = DataLoader(train_data,
                                  num_workers=args.num_workers,
                                  batch_size=args.batch_size,
                                  drop_last=True,
                                  pin_memory=True)
        val_loader = DataLoader(val_data,
                                num_workers=args.num_workers,
                                batch_size=args.batch_size,
                                drop_last=True,
                                pin_memory=True)
        data_config = train_data.config
    else:
        filelist = sorted(sum([glob.glob(f) for f in args.data_test], []))
        test_data = SimpleIterDataset(filelist,
                                      args.data_config,
                                      for_training=False,
                                      files_per_fetch=1)
        test_loader = DataLoader(test_data,
                                 num_workers=args.num_workers,
                                 batch_size=args.batch_size,
                                 drop_last=False,
                                 pin_memory=True)
        data_config = test_data.config

    # model: the network config path is turned into a dotted module name
    network_module = import_module(
        args.network_config.replace('.py', '').replace('/', '.'))
    network_options = {k: ast.literal_eval(v) for k, v in args.network_option}
    if args.export_onnx:
        network_options['for_inference'] = True
    model, model_info = network_module.get_model(data_config,
                                                 **network_options)
    _logger.info(model)

    # export to ONNX
    if args.export_onnx:
        assert (args.export_onnx.endswith('.onnx'))
        model_path = args.model_prefix
        _logger.info('Exporting model %s to ONNX' % model_path)
        model.load_state_dict(torch.load(model_path, map_location='cpu'))
        model = model.cpu()
        model.eval()

        # a bare filename has an empty dirname, which os.makedirs rejects
        onnx_dir = os.path.dirname(args.export_onnx)
        if onnx_dir:
            os.makedirs(onnx_dir, exist_ok=True)
        inputs = tuple(
            torch.ones(model_info['input_shapes'][k], dtype=torch.float32)
            for k in model_info['input_names'])
        torch.onnx.export(model,
                          inputs,
                          args.export_onnx,
                          input_names=model_info['input_names'],
                          output_names=model_info['output_names'],
                          dynamic_axes=model_info.get('dynamic_axes', None),
                          opset_version=11)
        _logger.info('ONNX model saved to %s', args.export_onnx)
        return

    # note: we should always save/load the state_dict of the original model, not the one wrapped by nn.DataParallel
    # so we do not convert it to nn.DataParallel now
    model = model.to(dev)

    # loss function: fall back to cross entropy when the network module does
    # not provide `get_loss`
    try:
        loss_func = network_module.get_loss(data_config, **network_options)
        _logger.info(loss_func)
    except AttributeError:
        loss_func = torch.nn.CrossEntropyLoss()
        _logger.warning(
            'Loss function not defined in %s. Will use `torch.nn.CrossEntropyLoss()` by default.',
            args.network_config)

    if training_mode:
        # optimizer & learning rate
        # (no scheduler is created when running the lr finder, which returns
        # before the training loop)
        if args.optimizer == 'adam':
            opt = torch.optim.Adam(model.parameters(), lr=args.start_lr)
            if args.lr_finder is None:
                lr_steps = [int(x) for x in args.lr_steps.split(',')]
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    opt, milestones=lr_steps, gamma=0.1)
        else:
            from utils.nn.optimizer.ranger import Ranger
            opt = Ranger(model.parameters(), lr=args.start_lr)
            if args.lr_finder is None:
                # decay over the last 30% of the epochs; the per-epoch factor
                # is chosen so that the total decay is 0.01
                lr_decay_epochs = max(1, int(args.num_epochs * 0.3))
                lr_decay_rate = 0.01**(1. / lr_decay_epochs)
                scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    opt,
                    milestones=list(
                        range(args.num_epochs - lr_decay_epochs,
                              args.num_epochs)),
                    gamma=lr_decay_rate)

        # load previous training and resume if `--load-epoch` is set
        if args.load_epoch is not None:
            _logger.info('Resume training from epoch %d' % args.load_epoch)
            model_state = torch.load(args.model_prefix +
                                     '_epoch-%d_state.pt' % args.load_epoch,
                                     map_location=dev)
            model.load_state_dict(model_state)
            opt_state = torch.load(args.model_prefix +
                                   '_epoch-%d_optimizer.pt' % args.load_epoch,
                                   map_location=dev)
            opt.load_state_dict(opt_state)

        # multi-gpu
        if gpus is not None and len(gpus) > 1:
            model = torch.nn.DataParallel(
                model, device_ids=gpus
            )  # model becomes `torch.nn.DataParallel` w/ model.module being the original `torch.nn.Module`
        model = model.to(dev)

        # lr finder: keep it after all other setups
        if args.lr_finder is not None:
            start_lr, end_lr, num_iter = args.lr_finder.replace(' ',
                                                                '').split(',')
            from utils.lr_finder import LRFinder
            lr_finder = LRFinder(model,
                                 opt,
                                 loss_func,
                                 device=dev,
                                 input_names=train_data.config.input_names,
                                 label_names=train_data.config.label_names)
            lr_finder.range_test(train_loader,
                                 start_lr=float(start_lr),
                                 end_lr=float(end_lr),
                                 num_iter=int(num_iter))
            lr_finder.plot(output='lr_finder.png'
                           )  # to inspect the loss-learning rate graph
            return

        # training loop
        best_valid_acc = 0
        for epoch in range(args.num_epochs):
            # when resuming, skip every epoch up to and including the loaded one
            if args.load_epoch is not None:
                if epoch <= args.load_epoch:
                    continue
            print('-' * 50)
            _logger.info('Epoch #%d training' % epoch)
            train(model, loss_func, opt, scheduler, train_loader, dev)
            if args.model_prefix:
                dirname = os.path.dirname(args.model_prefix)
                if dirname and not os.path.exists(dirname):
                    os.makedirs(dirname)
                # always save the unwrapped model's weights
                state_dict = model.module.state_dict() if isinstance(
                    model, torch.nn.DataParallel) else model.state_dict()
                torch.save(state_dict,
                           args.model_prefix + '_epoch-%d_state.pt' % epoch)
                torch.save(
                    opt.state_dict(),
                    args.model_prefix + '_epoch-%d_optimizer.pt' % epoch)

            _logger.info('Epoch #%d validating' % epoch)
            valid_acc = evaluate(model, val_loader, dev, loss_func=loss_func)
            if valid_acc > best_valid_acc:
                best_valid_acc = valid_acc
                if args.model_prefix:
                    # keep a copy of this epoch's weights as the best so far
                    shutil.copy2(
                        args.model_prefix + '_epoch-%d_state.pt' % epoch,
                        args.model_prefix + '_best_acc_state.pt')
                    torch.save(model, args.model_prefix + '_best_acc_full.pt')
            _logger.info(
                'Epoch #%d: Current validation acc: %.5f (best: %.5f)' %
                (epoch, valid_acc, best_valid_acc))
    else:
        # run prediction
        if args.model_prefix.endswith('.onnx'):
            _logger.info('Loading model %s for eval' % args.model_prefix)
            from utils.nn.tools import evaluate_onnx
            test_acc, scores, labels, observers = evaluate_onnx(
                args.model_prefix, test_loader)
        else:
            # a prefix without the `.pt` extension selects the best-accuracy
            # weights saved during training
            model_path = args.model_prefix if args.model_prefix.endswith(
                '.pt') else args.model_prefix + '_best_acc_state.pt'
            _logger.info('Loading model %s for eval' % model_path)
            model.load_state_dict(torch.load(model_path, map_location=dev))
            if gpus is not None and len(gpus) > 1:
                model = torch.nn.DataParallel(model, device_ids=gpus)
            model = model.to(dev)
            test_acc, scores, labels, observers = evaluate(model,
                                                           test_loader,
                                                           dev,
                                                           for_training=False)
        _logger.info('Test acc %.5f' % test_acc)

        if args.predict_output:
            # a bare filename has an empty dirname, which os.makedirs rejects
            output_dir = os.path.dirname(args.predict_output)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            if args.predict_output.endswith('.root'):
                from utils.data.fileio import _write_root
                output = {}
                # one boolean truth column and one score column per class
                for idx, label_name in enumerate(data_config.label_value):
                    output[label_name] = (
                        labels[data_config.label_names[0]] == idx)
                    output['score_' + label_name] = scores[:, idx]
                for k, v in labels.items():
                    if k == data_config.label_names[0]:
                        continue
                    if v.ndim > 1:
                        _logger.warning('Ignoring %s, not a 1d array.', k)
                        continue
                    output[k] = v
                for k, v in observers.items():
                    if v.ndim > 1:
                        _logger.warning('Ignoring %s, not a 1d array.', k)
                        continue
                    output[k] = v
                _write_root(args.predict_output, output)
            else:
                import awkward
                output = {'scores': scores}
                output.update(labels)
                output.update(observers)
                awkward.save(args.predict_output, output, mode='w')

            _logger.info('Written output to %s' % args.predict_output)
Ejemplo n.º 6
0
def train_load(args):
    """
    Loads the training data.

    If ``args.data_val`` is set, separate validation files are used and both
    datasets load the full ``(0, 1)`` range. Otherwise the validation dataset
    reads the same files as the training dataset, with the ranges split at
    ``args.train_val_split``.

    When ``args.demo`` is set, both file lists are cut to their first 20
    entries and ``args.data_fraction`` / ``args.fetch_step`` are overwritten
    on ``args`` with 0.1 / 0.002.

    :param args: parsed command-line options
    :return: train_loader, val_loader, data_config, train_input_names, train_label_names
    :raises RuntimeError: if ``args.in_memory`` is set while
        ``args.steps_per_epoch`` or ``args.steps_per_epoch_val`` is None
    """

    train_files = to_filelist(args, 'train')
    if args.data_val:
        val_files = to_filelist(args, 'val')
        train_range = val_range = (0, 1)
    else:
        val_files = train_files
        train_range = (0, args.train_val_split)
        val_range = (args.train_val_split, 1)
    _logger.info('Using %d files for training, range: %s' %
                 (len(train_files), str(train_range)))
    _logger.info('Using %d files for validation, range: %s' %
                 (len(val_files), str(val_range)))

    if args.demo:
        train_files = train_files[:20]
        val_files = val_files[:20]
        _logger.info(train_files)
        _logger.info(val_files)
        args.data_fraction = 0.1
        args.fetch_step = 0.002

    if args.in_memory and (args.steps_per_epoch is None
                           or args.steps_per_epoch_val is None):
        raise RuntimeError(
            'Must set --steps-per-epoch when using --in-memory!')

    train_data = SimpleIterDataset(
        train_files,
        args.data_config,
        for_training=True,
        load_range_and_fraction=(train_range, args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        infinity_mode=args.steps_per_epoch is not None,
        in_memory=args.in_memory)
    val_data = SimpleIterDataset(val_files,
                                 args.data_config,
                                 for_training=True,
                                 load_range_and_fraction=(val_range,
                                                          args.data_fraction),
                                 file_fraction=args.file_fraction,
                                 fetch_by_files=args.fetch_by_files,
                                 fetch_step=args.fetch_step,
                                 infinity_mode=args.steps_per_epoch_val
                                 is not None,
                                 in_memory=args.in_memory)

    # cap the worker counts by the (fraction-scaled) number of files
    train_workers = min(args.num_workers,
                        int(len(train_files) * args.file_fraction))
    val_workers = min(args.num_workers,
                      int(len(val_files) * args.file_fraction))
    # persistent_workers must follow the *effective* worker count: DataLoader
    # raises ValueError for persistent_workers=True with num_workers == 0,
    # which happens when the cap above reduces the count to zero
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              drop_last=True,
                              pin_memory=True,
                              num_workers=train_workers,
                              persistent_workers=train_workers > 0
                              and args.steps_per_epoch is not None)
    val_loader = DataLoader(val_data,
                            batch_size=args.batch_size,
                            drop_last=True,
                            pin_memory=True,
                            num_workers=val_workers,
                            persistent_workers=val_workers > 0
                            and args.steps_per_epoch_val is not None)
    data_config = train_data.config
    train_input_names = train_data.config.input_names
    train_label_names = train_data.config.label_names

    return train_loader, val_loader, data_config, train_input_names, train_label_names
Ejemplo n.º 7
0
def train_load(args):
    """
    Loads the training data.

    The glob patterns in ``args.data_train`` are expanded and sorted. If
    ``args.copy_inputs`` is set, every file is first copied into a fresh
    temporary directory (keeping its original path below that directory) and
    the copies are used instead. The file list is then shuffled in place with
    ``np.random.shuffle`` (no seed is set here). Training and validation read
    the same files, with the ranges split at ``args.train_val_split``.

    When ``args.demo`` is set, only the first 20 files are kept and
    ``args.data_fraction`` / ``args.fetch_step`` are overwritten on ``args``
    with 0.1 / 0.002.

    :param args: parsed command-line options
    :return: train_loader, val_loader, data_config, train_input_names, train_label_names
    """
    filelist = sorted(sum([glob.glob(f) for f in args.data_train], []))
    if args.copy_inputs:
        import tempfile
        # keep the directory exactly as mkdtemp created it; removing it and
        # recreating it later would discard the securely created directory
        # NOTE(review): the temporary directory is never removed by this
        # function
        tmpdir = tempfile.mkdtemp()
        new_filelist = []
        for src in filelist:
            # strip leading '/' so that absolute paths land below tmpdir
            dest = os.path.join(tmpdir, src.lstrip('/'))
            os.makedirs(os.path.dirname(dest), exist_ok=True)
            shutil.copy2(src, dest)
            _logger.info('Copied file %s to %s' % (src, dest))
            new_filelist.append(dest)
        filelist = new_filelist

    # note: the shuffle is not seeded
    np.random.shuffle(filelist)
    if args.demo:
        filelist = filelist[:20]
        _logger.info(filelist)
        args.data_fraction = 0.1
        args.fetch_step = 0.002
    # cap the worker count by the (fraction-scaled) number of input files
    num_workers = min(args.num_workers,
                      int(len(filelist) * args.file_fraction))
    train_data = SimpleIterDataset(
        filelist,
        args.data_config,
        for_training=True,
        load_range_and_fraction=((0, args.train_val_split),
                                 args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        infinity_mode=args.steps_per_epoch is not None,
        in_memory=args.in_memory)
    val_data = SimpleIterDataset(
        filelist,
        args.data_config,
        for_training=True,
        load_range_and_fraction=((args.train_val_split, 1),
                                 args.data_fraction),
        file_fraction=args.file_fraction,
        fetch_by_files=args.fetch_by_files,
        fetch_step=args.fetch_step,
        infinity_mode=args.steps_per_epoch is not None,
        in_memory=args.in_memory)
    # persistent workers are only enabled with a non-zero effective worker count
    persistent_workers = num_workers > 0 and args.steps_per_epoch is not None
    train_loader = DataLoader(train_data,
                              num_workers=num_workers,
                              batch_size=args.batch_size,
                              drop_last=True,
                              pin_memory=True,
                              persistent_workers=persistent_workers)
    val_loader = DataLoader(val_data,
                            num_workers=num_workers,
                            batch_size=args.batch_size,
                            drop_last=True,
                            pin_memory=True,
                            persistent_workers=persistent_workers)
    data_config = train_data.config
    train_input_names = train_data.config.input_names
    train_label_names = train_data.config.label_names

    return train_loader, val_loader, data_config, train_input_names, train_label_names