Exemplo n.º 1
0
def fork_experiment(from_exp: Path, to_exp: Path, conf: bool, vocab: bool, data: bool, code: bool):
    """Fork (selectively clone) an experiment directory into a new location.

    Args:
        from_exp: existing experiment dir to fork from (must exist)
        to_exp: destination experiment dir; created if missing
        conf: copy conf.yml (backing up any pre-existing one at the destination)
        vocab: copy vocabularies; only takes effect when ``data`` is False
        data: symlink the source 'data' dir into the destination and mark it _PREPARED
        code: copy code-snapshot files ('rtg.zip', 'githead') when present
    """
    assert from_exp.exists()
    log.info(f'Fork: {str(from_exp)} → {str(to_exp)}')
    if not to_exp.exists():
        log.info(f"Create dir {str(to_exp)}")
        to_exp.mkdir(parents=True)
    if conf:
        conf_file = to_exp / 'conf.yml'
        # back up any pre-existing conf at the destination before overwriting it
        IO.maybe_backup(conf_file)
        IO.copy_file(from_exp / 'conf.yml', conf_file)
    if data:
        to_data_dir = (to_exp / 'data')
        from_data_dir = from_exp / 'data'
        if to_data_dir.is_symlink():
            # replace a stale link; a real directory here would trip the assert below
            log.info(f"removing the existing data link: {to_data_dir.resolve()}")
            to_data_dir.unlink()
        assert not to_data_dir.exists()
        assert from_data_dir.exists()
        log.info(f"link {to_data_dir} → {from_data_dir}")
        # link to the resolved (absolute) path so the symlink stays valid
        to_data_dir.symlink_to(from_data_dir.resolve())
        (to_exp / '_PREPARED').touch(exist_ok=True)
    if not data and vocab: # just the vocab
        # presumably the linked data dir already provides vocabs, so a separate
        # copy is only needed when data isn't linked — TODO confirm
        Experiment(from_exp, read_only=True).copy_vocabs(
            Experiment(to_exp, config={'Not': 'Empty'}, read_only=True))

    if code:
        for f in ['rtg.zip', 'githead']:
            src = from_exp / f
            if not src.exists():
                log.warning(f"File Not Found: {src}")
                continue
            IO.copy_file(src, to_exp / f)
Exemplo n.º 2
0
def main():
    """Train an experiment: seed RNGs if requested, pick the backend, and run.

    Reads 'seed' and 'work_dir' from the parsed CLI args; the remaining args
    are forwarded to ``exp.train``. Raises AssertionError if the experiment
    dir has not been prepared via the "prep" sub task.
    """
    args = parse_args()
    seed = args.pop("seed")
    if seed:
        log.info(f"Seed for random number generator: {seed}")
        # lazy imports: only needed when seeding is requested
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    work_dir = Path(args.pop('work_dir'))
    # a non-empty 'spark' section in conf.yml enables the big-experiment backend
    is_big = load_conf(work_dir / 'conf.yml').get('spark', {})

    if is_big:
        log.info("Big experiment mode enabled; checking pyspark backend")
        try:
            import pyspark
        except ImportError:
            # narrowed from a bare `except:` — only a missing package is expected
            # here, and a bare except would also swallow KeyboardInterrupt
            log.warning("unable to import pyspark. Please do 'pip install pyspark' and run again")
            raise
        from rtg.big.exp import BigTranslationExperiment
        exp = BigTranslationExperiment(work_dir=work_dir)
    else:
        exp = Experiment(work_dir=work_dir)
    assert exp.has_prepared(), f'Experiment dir {exp.work_dir} is not ready to train. ' \
                               f'Please run "prep" sub task'
    exp.train(args)
Exemplo n.º 3
0
def main(args=None):
    """Print per-side type/token counts, length stats, and vocabulary imbalance."""
    args = args or parse_args()
    exp = Experiment(args.exp)
    assert exp.train_db.exists()
    train_data = SqliteFile(exp.train_db)
    shared = exp.src_vocab is exp.tgt_vocab
    print(f"Experiment: {exp.work_dir} shared_vocab:{shared}")
    for side, var in [('src', 'x'), ('tgt', 'y')]:
        # count token frequencies over the whole training set for this side
        term_freqs = coll.Counter()
        for rec in train_data.get_all([var]):
            term_freqs.update(rec[var])
        lens = np.array([rec[f'{var}_len']
                         for rec in train_data.get_all([f'{var}_len'])])
        tot_toks = sum(lens)
        n_types = len(exp.src_vocab) if side == 'src' else len(exp.tgt_vocab)
        # sanity: per-token counts must agree with the stored lengths
        assert sum(term_freqs.values()) == tot_toks

        uniform = 1 / n_types
        # total variation distance from the uniform distribution
        div = 0.5 * sum(abs(uniform - freq / tot_toks)
                        for freq in term_freqs.values())
        print(
            f"{side} types: {n_types} toks: {tot_toks:,} len_mean: {np.mean(lens):.4f} "
            f"len_median: {np.median(lens)} imbalance: {div:.4f}")
    # `lens` intentionally holds the last side's lengths; both sides have equal n_segs
    print(f'n_segs: {len(lens):,}')
Exemplo n.º 4
0
def __test_model__():
    """Smoke-test: train a small decoder-only transformer on the sample experiment."""
    model_args = {
        'enc_layers': 0,
        'dec_layers': 4,
        'hid_size': 64,
        'ff_size': 64,
        'n_heads': 4,
        'activation': 'relu'
    }

    # if you are running this in pycharm, please set Working Dir=<rtg repo base dir> for run config
    # renamed from `dir`, which shadowed the builtin of the same name
    work_dir = 'experiments/sample-exp'
    exp = Experiment(work_dir=work_dir, read_only=True)

    exp.model_type = 'tfmnmt'
    exp.model_args.update(model_args)
    # optim_args[1] holds the optimizer/criterion kwargs — TODO confirm index semantics
    exp.optim_args[1].update(
        dict(criterion='smooth_kld',
             warmup_steps=500,
             weighing={'gamma': [0.0, 0.5]}))

    trainer = TransformerTrainer(exp=exp, **exp.optim_args[1])
    assert 2 == exp.tgt_vocab.bos_idx
    batch_size = 256
    steps = 2000
    check_point = 200
    trainer.train(steps=steps, check_point=check_point, batch_size=batch_size)
Exemplo n.º 5
0
def main(args=None):
    """Measure frequency bias of precision/recall against training-set token frequencies."""
    args = args or parse_args()
    exp = Experiment(args.exp, read_only=True)
    n_classes = len(exp.tgt_vocab)
    freqs = get_training_frequencies(args.freq, n_classes=n_classes)
    # tokenize without BOS/EOS so sys/ref token streams align with class ids
    tokr = partial(exp.tgt_vocab.encode_as_ids, add_bos=False, add_eos=False)
    sys = [tokr(line.strip()) for line in args.sys]
    ref = [tokr(line.strip()) for line in args.ref]
    assert len(sys) == len(ref)

    precision, recall = evaluate(sys, ref)
    # fixed user-facing typo: was 'Precsion bias: '
    print('Precision bias: ', frequency_bias(freqs, precision))
    print('Recall bias: ', frequency_bias(freqs, recall))
Exemplo n.º 6
0
File: train.py Project: MGheini/rtg
def main():
    """CLI entry: optionally seed RNGs, then run training on the experiment dir."""
    args = parse_args()
    seed = args.pop('seed')
    if seed:
        log.info(f"Seed for random number generator: {seed}")
        # lazy imports: random/torch are only needed when seeding is requested
        import random
        import torch
        random.seed(seed)
        torch.manual_seed(seed)

    exp = Experiment(args.pop('work_dir'))
    assert exp.has_prepared(), (f'Experiment dir {exp.work_dir} is not ready to train. '
                                f'Please run "prep" sub task')
    exp.train(args)
Exemplo n.º 7
0
def main():
    """Compute and print log-perplexity and perplexity of a language model."""
    # decoding only — disable autograd globally
    torch.set_grad_enabled(False)
    args = parse_args()
    exp = Experiment(args.pop('work_dir'), read_only=True)
    assert exp.model_type.endswith('lm'), 'Only for Language models'

    gen_args = {}
    decoder = Decoder.new(
        exp,
        gen_args=gen_args,
        model_paths=args.pop('model_path', None),
        ensemble=args.pop('ensemble', 1),
    )

    log_pp = log_perplexity(decoder, args['test'])
    print(f'Log perplexity: {log_pp:g}')
    print(f'Perplexity: {math.exp(log_pp):g}')
Exemplo n.º 8
0
def main():
    """Decode with an NMT model: batch file mode or an interactive shell.

    Pops decoder-related keys off the parsed args; the remainder is forwarded
    to the decode calls.
    """
    # No grads required
    torch.set_grad_enabled(False)
    args = parse_args()
    gen_args = {}
    exp = Experiment(args.pop('work_dir'), read_only=True)
    validate_args(args, exp)

    if exp.model_type == 'binmt':
        if not args.get('path'):
            # BUG FIX: the exception object was constructed but never raised,
            # so the missing-argument error was silently ignored
            raise Exception('--binmt-path argument is needed for BiNMT model.')
        gen_args['path'] = args.pop('binmt_path')

    weights = args.get('weights')
    if weights:
        # combo mode: weighted combination of multiple models
        decoder = Decoder.combo_new(exp,
                                    model_paths=args.pop('model_path'),
                                    weights=weights)
    else:
        decoder = Decoder.new(exp,
                              gen_args=gen_args,
                              model_paths=args.pop('model_path', None),
                              ensemble=args.pop('ensemble', 1))
    if args.pop('interactive'):
        if weights:
            log.warning(
                "Interactive shell not reloadable for combo mode. FIXME: TODO:"
            )
        if args['input'] != sys.stdin or args['output'] != sys.stdout:
            log.warning(
                '--input and --output args are not applicable in --interactive mode'
            )
        args.pop('input')
        args.pop('output')

        while True:
            try:
                # an hacky way to unload and reload model when user tries to switch models
                decoder.decode_interactive(**args)
                break  # exit loop if there is no request for reload
            except ReloadEvent as re:
                decoder = Decoder.new(exp,
                                      gen_args=gen_args,
                                      model_paths=re.model_paths)
                args = re.state
                # go back to loop and redo interactive shell
    else:
        return decoder.decode_file(args.pop('input'), args.pop('output'),
                                   **args)
Exemplo n.º 9
0
def parse_args():
    """Parse CLI args for experiment preparation.

    Note: despite the name, this returns an initialized ``Experiment`` built
    from the positional ``exp`` dir and an optional ``conf`` file.
    """
    parser = argparse.ArgumentParser(prog="rtg.prep",
                                     description="prepare NMT experiment")
    parser.add_argument("exp", type=Path,
                        help="Working directory of experiment")
    parser.add_argument("conf", type=Path, nargs='?',
                        help="Config File. By default <work_dir>/conf.yml is used")
    args = parser.parse_args()
    # fall back to the conventional location inside the experiment dir
    conf_file: Path = args.conf or args.exp / 'conf.yml'
    assert conf_file.exists()
    return Experiment(args.exp, config=conf_file)
Exemplo n.º 10
0
def main():
    """Decode every (input, output) file pair given on the command line."""
    # No grads required for decode
    torch.set_grad_enabled(False)
    cli_args = parse_args()
    exp = Experiment(cli_args.pop('exp_dir'), read_only=True)
    # decoder settings: prefer top-level 'decoder' in conf, else tester.decoder
    dec_args = exp.config.get('decoder') or exp.config['tester'].get('decoder', {})
    validate_args(cli_args, dec_args, exp)

    decoder = Decoder.new(exp, ensemble=dec_args.pop('ensemble', 1))
    for inp, out in zip(cli_args['input'], cli_args['output']):
        log.info(f"Decode :: {inp} -> {out}")
        try:
            # BUG FIX: the original `return`ed inside the loop, so only the
            # first (input, output) pair was ever decoded
            if cli_args.get('no_buffer'):
                decoder.decode_stream(inp, out, **dec_args)
            else:
                decoder.decode_file(inp, out, **dec_args)
        except Exception:
            # narrowed from a bare except (which also caught KeyboardInterrupt);
            # a failed pair is logged and the remaining pairs still run
            log.exception(f"Decode failed for {inp}")
Exemplo n.º 11
0
def main(args):
    """Set up a dummy experiment dir with synthetic parallel data and a config."""
    work_dir: Path = args.pop('exp')

    work_dir.mkdir(exist_ok=True, parents=True)
    log.info(f"Setting up a dummy experiment at {work_dir}")
    num_train = args.pop('num_train')
    num_val = args.pop('num_val')

    # remaining args parameterize the synthetic data generator
    train_data = generate_parallel(**args, num_exs=num_train)
    val_data = generate_parallel(**args, num_exs=num_val)

    train_files = (str(work_dir / 'train.raw.src'), str(work_dir / 'train.raw.tgt'))
    val_files = (str(work_dir / 'valid.raw.src'), str(work_dir / 'valid.raw.tgt'))
    write_parallel(train_data, *train_files)
    write_parallel(val_data, *val_files)

    prep = {
        'train_src': train_files[0],
        'train_tgt': train_files[1],
        'valid_src': val_files[0],
        'valid_tgt': val_files[1],
        'pieces': 'word',
        'truncate': True,
        'src_len': args['max_len'],
        'tgt_len': args['max_len'],
    }

    if args.get('rev_vocab'):
        # shared vocabulary would be confusing
        prep['shared_vocab'] = False
        prep['max_src_types'] = args['vocab_size']
        prep['max_tgt_types'] = args['vocab_size']
    else:
        prep['shared_vocab'] = True
        prep['max_types'] = args['vocab_size']
    exp = Experiment(work_dir, config={'prep': prep})
    exp.store_config()
Exemplo n.º 12
0
def attach_translate_route(cli_args):
    """Attach /translate, /conf.yml and /about routes to the blueprint ``bp``.

    Loads the experiment and decoder once at attach time; the route handlers
    close over them. NOTE: mutates the module-level ``exp`` global.
    """
    global exp
    exp = Experiment(cli_args.pop("exp_dir"), read_only=True)
    # decoder settings: prefer top-level 'decoder' in conf, else tester.decoder
    dec_args = exp.config.get("decoder") or exp.config["tester"].get(
        "decoder", {})
    decoder = Decoder.new(exp, ensemble=dec_args.pop("ensemble", 1))
    dataprep = RtgIO(exp=exp)

    @bp.route("/translate", methods=["POST", "GET"])
    def translate():
        if request.method not in ("POST", "GET"):
            return "GET and POST are supported", 400
        if request.method == 'GET':
            sources = request.args.getlist("source", None)
        else:
            # POST: accept either a JSON body {'source': ...} or form field(s) "source"
            sources = (request.json or {}).get(
                'source', None) or request.form.getlist("source")
            if isinstance(sources, str):
                sources = [sources]
        if not sources:
            return "Please submit parameter 'source'", 400
        sources = [dataprep.pre_process(sent) for sent in sources]
        translations = []
        for source in sources:
            # presumably [0][1] selects the text of the top-scoring hypothesis —
            # verify against Decoder.decode_sentence's return shape
            translated = decoder.decode_sentence(source, **dec_args)[0][1]
            translated = dataprep.post_process(translated.split())
            translations.append(translated)

        res = dict(source=sources, translation=translations)
        return jsonify(res)

    @bp.route("/conf.yml", methods=["GET"])
    def get_conf():
        # serve the raw experiment config for inspection
        conf_str = exp._config_file.read_text(encoding='utf-8',
                                              errors='ignore')
        return render_template('conf.yml.html', conf_str=conf_str)

    @bp.route("/about", methods=["GET"])
    def about():
        def_desc = "Model description is unavailable; please update conf.yml"
        return render_template('about.html',
                               model_desc=exp.config.get(
                                   "description", def_desc))
Exemplo n.º 13
0
def test_lm():
    """Smoke-train an RNN language model on a pre-existing experiment dir."""
    # NOTE(review): hard-coded personal path — this only runs on the author's machine
    work_dir = '/Users/tg/work/me/rtg/saral/runs/1S-rnnlm-basic'
    exp = Experiment(work_dir)
    trainer = RnnLmTrainer(exp=exp)
    trainer.train(batch_size=64, steps=2000, check_point=100)