Example #1
def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    parser = options.get_training_parser()
    options.add_pruning_args(parser)
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)

    cfg = convert_namespace_to_omegaconf(args)

    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)
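These snippets share a common set of fairseq imports that the listing omits. As a reference, a minimal import sketch based on fairseq_cli/train.py (module paths vary slightly across fairseq versions; older releases expose fairseq.distributed_utils directly instead of fairseq.distributed.utils):

# Import sketch assumed by the cli_main snippets in this section
# (based on fairseq_cli/train.py; adjust to your fairseq version).
# `main`, `logger`, etc. are defined elsewhere in each script.
import argparse
from typing import Callable, Optional

import torch

from fairseq import options
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import utils as distributed_utils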
Example #2
File: validate.py  Project: zvict/NSVF
def cli_main():
    parser = options.get_validation_parser()
    args = options.parse_args_and_arch(parser)

    # only override args that are explicitly given on the command line
    override_parser = options.get_validation_parser()
    override_args = options.parse_args_and_arch(override_parser,
                                                suppress_defaults=True)

    # support multi-gpu validation, use all available gpus
    default_world_size = max(1, torch.cuda.device_count())
    if args.distributed_world_size < default_world_size:
        args.distributed_world_size = default_world_size
        override_args.distributed_world_size = default_world_size

    distributed_utils.call_main(args, main, override_args=override_args)
Example #3
def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    # print('input:', modify_parser) # None
    parser = options.get_training_parser()
    print('parser:', parser)
    print('modify_parser:', modify_parser)
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
    # print('args', args)
    cfg = convert_namespace_to_omegaconf(args)
    # print('cfg:', cfg)
    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)
Example #4
def _hydra_main(cfg: FairseqConfig, **kwargs) -> float:
    add_defaults(cfg)

    if cfg.common.reset_logging:
        reset_logging()  # Hydra hijacks logging, fix that
    else:
        # check if directly called or called through hydra_main
        if HydraConfig.initialized():
            with open_dict(cfg):
                # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
                cfg.job_logging_cfg = OmegaConf.to_container(
                    HydraConfig.get().job_logging, resolve=True)

    with omegaconf_no_object_check():
        cfg = OmegaConf.create(
            OmegaConf.to_container(cfg, resolve=True, enum_to_str=True))
    OmegaConf.set_struct(cfg, True)

    try:
        if cfg.common.profile:
            with torch.cuda.profiler.profile():
                with torch.autograd.profiler.emit_nvtx():
                    distributed_utils.call_main(cfg, pre_main, **kwargs)
        else:
            distributed_utils.call_main(cfg, pre_main, **kwargs)
    except BaseException as e:
        if not cfg.common.suppress_crashes:
            raise
        else:
            logger.error("Crashed! " + str(e))

    # get best val and return - useful for sweepers
    try:
        best_val = metrics.get_smoothed_value(
            "valid", cfg.checkpoint.best_checkpoint_metric)
    except Exception:
        best_val = None

    if best_val is None:
        best_val = float("inf")

    return best_val
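For context, _hydra_main is normally reached through a thin wrapper decorated with @hydra.main; a minimal sketch along the lines of fairseq_cli/hydra_train.py (the exact config_path is version-dependent):

import os

import hydra
from fairseq.dataclass.configs import FairseqConfig


@hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config")
def hydra_main(cfg: FairseqConfig) -> float:
    # Delegate to the undecorated entry point shown above so it can also
    # be called directly (e.g. from tests) without Hydra's decorator.
    return _hydra_main(cfg)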
Example #5
def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
    if args.debug:
        args.device_id = 0
        args.distributed_rank = 0
        args.distributed_world_size = 1
        #args.train_subset = args.valid_subset
        args.num_workers = 0
        args.dropout = 0
    cfg = convert_namespace_to_omegaconf(args)

    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)
Example #6
def cli_main(path, model_overrides, name):
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    args.max_sentences = 2
    args.tokens_per_sample = 512
    args.context_window = 400
    # args.cpu = True
    # args.num_shards = 300

    gl._init()  # gl: project-specific global-state helper (see sketch after this example)
    gl.set_value('visualize', True)
    gl.set_value('attn_weight_layers', [0 for _ in range(16)])
    gl.set_value('attn_weight_heads', [0 for _ in range(8)])
    gl.set_value('attn_weight_counts', [0 for _ in range(16)])
    gl.set_value('current_layer', 0)

    args.path = path
    args.model_overrides = model_overrides

    distributed_utils.call_main(args, main)

    if name == 'layer':
        attn_weight_layers = gl.get_value('attn_weight_layers')
        attn_weight_counts = gl.get_value('attn_weight_counts')

        attn_weight_layers = [
            x / attn_weight_counts[idx]
            for idx, x in enumerate(attn_weight_layers)
        ]
        torch.save(attn_weight_layers, path + "." + name)
    elif name == 'head':
        attn_weight_heads = gl.get_value('attn_weight_heads')
        attn_weight_counts = gl.get_value('attn_weight_counts')

        attn_weight_heads = [
            x / attn_weight_counts[idx]
            for idx, x in enumerate(attn_weight_heads)
        ]
        torch.save(attn_weight_heads, path + "." + name)
    else:
        exit(0)
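The gl module used above is not part of fairseq; it appears to be a small project-specific global-state helper. A hypothetical minimal implementation matching the calls in the snippet (_init, set_value, get_value):

# gl.py -- hypothetical global-state helper; names and signatures are
# inferred from the usage above, not taken from any published module.
_global_dict = {}


def _init():
    # Reset the shared dictionary.
    global _global_dict
    _global_dict = {}


def set_value(key, value):
    _global_dict[key] = value


def get_value(key, default=None):
    return _global_dict.get(key, default)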
Example #7
def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)
    print_options_meaning_changes(args)

    cfg = convert_namespace_to_omegaconf(args)

    if cfg.common.use_plasma_view:
        server = PlasmaStore(path=cfg.common.plasma_path)
        logger.info(
            f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}"
        )

    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)
Example #8
def cli_main():
    parser = options.get_interactive_generation_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
Example #9
def cli_main():
    parser = options.get_interactive_generation_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)
Example #10
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)
Example #11
def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)

    distributed_utils.call_main(convert_namespace_to_omegaconf(args), main)
Example #12
def cli_main():
    parser = options.get_rendering_parser()
    add_distributed_training_args(parser)
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)
Example #13
def cli_main():
    parser = options.get_generation_parser(interactive=False)
    parser.add_argument('--no-print', action='store_true')
    parser.add_argument('--truncate-size', default=512, type=int)
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)
Example #14
def cli_main():
    parser = options.get_interactive_generation_parser()
    parser.add_argument('--transformer-big-zhen', action='store_true')
    args = options.parse_args_and_arch(parser)
    distributed_utils.call_main(args, main)
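All of these cli_main variants are meant to be used as console entry points; the fairseq CLI scripts invoke them with the standard guard:

if __name__ == "__main__":
    cli_main()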