Example #1
def main():
    # No grads required
    torch.set_grad_enabled(False)
    args = parse_args()
    gen_args = {}
    exp = Experiment(args.pop('work_dir'), read_only=True)
    validate_args(args, exp)

    if exp.model_type == 'binmt':
        if not args.get('binmt_path'):
            raise Exception('--binmt-path argument is needed for BiNMT model.')
        gen_args['path'] = args.pop('binmt_path')

    weights = args.get('weights')
    if weights:
        decoder = Decoder.combo_new(exp,
                                    model_paths=args.pop('model_path'),
                                    weights=weights)
    else:
        decoder = Decoder.new(exp,
                              gen_args=gen_args,
                              model_paths=args.pop('model_path', None),
                              ensemble=args.pop('ensemble', 1))
    if args.pop('interactive'):
        if weights:
            log.warning(
                "Interactive shell not reloadable for combo mode. FIXME: TODO:"
            )
        if args['input'] != sys.stdin or args['output'] != sys.stdout:
            log.warning(
                '--input and --output args are not applicable in --interactive mode'
            )
        args.pop('input')
        args.pop('output')

        while True:
            try:
                # a hacky way to unload and reload the model when the user tries to switch models
                decoder.decode_interactive(**args)
                break  # exit loop if there is no request for reload
            except ReloadEvent as re:
                decoder = Decoder.new(exp,
                                      gen_args=gen_args,
                                      model_paths=re.model_paths)
                args = re.state
                # go back to loop and redo interactive shell
    else:
        return decoder.decode_file(args.pop('input'), args.pop('output'),
                                   **args)
Example #2
def __test_seq2seq_model__():
    """
        batch_size = 4
        p = '/Users/tg/work/me/rtg/saral/runs/1S-rnn-basic'
        exp = Experiment(p)
        steps = 3000
        check_pt = 100
        trainer = SteppedRNNMTTrainer(exp=exp, lr=0.01, warmup_steps=100)
        trainer.train(steps=steps, check_point=check_pt, batch_size=batch_size)
    """
    from rtg.dummy import DummyExperiment
    from rtg.module.decoder import Decoder

    vocab_size = 50
    batch_size = 30
    exp = DummyExperiment("tmp.work",
                          config={'model_type': 'seq2seq'},
                          read_only=True,
                          vocab_size=vocab_size)
    emb_size = 100
    model_dim = 100
    steps = 3000
    check_pt = 100

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                  [13, 12, 11, 10, 9, 8, 7, 6, 5, 4]])
    src_lens = tensor([src.size(1)] * src.size(0))

    for reverse in (False, ):
        # train two models;
        #  first, just copy the numbers, i.e. y = x
        #  second, reverse the numbers y=(V + reserved - x)
        log.info(f"====== REVERSE={reverse}; VOCAB={vocab_size}======")
        model, args = RNNMT.make_model('DummyA',
                                       'DummyB',
                                       vocab_size,
                                       vocab_size,
                                       attention='dot',
                                       emb_size=emb_size,
                                       hid_size=model_dim,
                                       n_layers=1)
        trainer = SteppedRNNMTTrainer(exp=exp,
                                      model=model,
                                      lr=0.01,
                                      warmup_steps=100)
        decr = Decoder.new(exp, model)

        def check_pt_callback(**args):
            res = decr.greedy_decode(src, src_lens, max_len=17)
            for score, seq in res:
                log.info(f'{score:.4f} :: {seq}')

        trainer.train(steps=steps,
                      check_point=check_pt,
                      batch_size=batch_size,
                      check_pt_callback=check_pt_callback)
Example #3
def __test_model__():
    from rtg.dummy import DummyExperiment
    vocab_size = 30
    args = {
        'src_vocab': vocab_size,
        'tgt_vocab': vocab_size,
        'enc_layers': 0,
        'dec_layers': 4,
        'hid_size': 64,
        'ff_size': 64,
        'n_heads': 4,
        'activation': 'gelu'
    }
    if False:
        for n, p in model.named_parameters():
            print(n, p.shape)

    from rtg.module.decoder import Decoder

    config = {
        'model_type': 'tfmnmt',
        'trainer': {
            'init_args': {
                'chunk_size': 2
            }
        }
    }
    exp = DummyExperiment("work.tmp.t2t",
                          config=config,
                          read_only=True,
                          vocab_size=vocab_size)
    exp.model_args = args
    trainer = TransformerTrainer(exp=exp, warmup_steps=200)
    decr = Decoder.new(exp, trainer.model)

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, Batch.eos_val, Batch.pad_value],
                  [13, 12, 11, 10, 9, 8, 7, 6, Batch.eos_val, Batch.pad_value,
                   Batch.pad_value, Batch.pad_value]])
    src_lens = tensor([src.size(1)] * src.size(0))

    def check_pt_callback(**args):
        res = decr.greedy_decode(src, src_lens, max_len=12)
        for score, seq in res:
            log.info(f'{score:.4f} :: {seq}')

    batch_size = 50
    steps = 1000
    check_point = 50
    trainer.train(steps=steps,
                  check_point=check_point,
                  batch_size=batch_size,
                  check_pt_callback=check_pt_callback)
Example #4
def __test_model__():
    from rtg.data.dummy import DummyExperiment
    from rtg import Batch, my_tensor as tensor

    vocab_size = 24
    args = {
        'src_vocab': vocab_size,
        'tgt_vocab': vocab_size,
        'enc_layers': 0,
        'dec_layers': 4,
        'hid_size': 32,
        'eff_dims': [],
        'dff_dims': [64, 128, 128, 64],
        'enc_depth_probs': [],
        'dec_depth_probs': [1.0, 0.75, 0.5, 0.75],
        'n_heads': 4,
        'activation': 'relu'
    }

    from rtg.module.decoder import Decoder

    config = {
        'model_type': 'wvskptfmnmt',
        'trainer': {'init_args': {'chunk_size': 2, 'grad_accum': 5}},
        'optim': {
            'args': {
                # "cross_entropy", "smooth_kld", "binary_cross_entropy", "triplet_loss"
                'criterion': "smooth_kld",
                'lr': 0.01,
                'inv_sqrt': True
            }
        }
    }

    exp = DummyExperiment("work.tmp.wvskptfmnmt", config=config, read_only=True,
                          vocab_size=vocab_size)
    exp.model_args = args
    trainer = WVSKPTransformerTrainer(exp=exp, warmup_steps=200, **config['optim']['args'])
    decr = Decoder.new(exp, trainer.model)

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, Batch.eos_val, Batch.pad_value],
                  [13, 12, 11, 10, 9, 8, 7, 6, Batch.eos_val, Batch.pad_value, Batch.pad_value,
                   Batch.pad_value]])
    src_lens = tensor([src.size(1)] * src.size(0))

    def check_pt_callback(**args):
        res = decr.greedy_decode(src, src_lens, max_len=12)
        for score, seq in res:
            log.info(f'{score:.4f} :: {seq}')

    batch_size = 50
    steps = 200
    check_point = 10
    trainer.train(steps=steps, check_point=check_point, batch_size=batch_size,
                  check_pt_callback=check_pt_callback)
Example #5
    def make_model(cls,
                   src_vocab,
                   tgt_vocab,
                   enc_layers=6,
                   dec_layers=6,
                   hid_size=512,
                   ff_size=2048,
                   n_heads=8,
                   dropout=0.1,
                   tied_emb='three-way',
                   activation='relu',
                   exp: Experiment = None):
        "Helper: Construct a model from hyper parameters."

        # get all args for reconstruction at a later phase
        _, _, _, args = inspect.getargvalues(inspect.currentframe())
        for exclusion in ['cls', 'exp']:
            del args[exclusion]  # exclude some args
        # In case you are wondering why I didn't use **kwargs here:
        #   these args are read from the conf file where the user can introduce errors, so parameter
        #   validation and default value assignment are implicitly done by the function call for us :)
        assert activation in {'relu', 'elu', 'gelu'}
        log.info(f"Make model, Args={args}")
        c = copy.deepcopy
        attn = MultiHeadedAttention(n_heads, hid_size, dropout=dropout)
        ff = PositionwiseFeedForward(hid_size,
                                     ff_size,
                                     dropout,
                                     activation=activation)

        if enc_layers == 0:
            log.warning("Zero encoder layers!")
        encoder = Encoder(EncoderLayer(hid_size, c(attn), c(ff), dropout),
                          enc_layers)

        assert dec_layers > 0
        decoder = Decoder(
            DecoderLayer(hid_size, c(attn), c(attn), c(ff), dropout),
            dec_layers)

        src_emb = nn.Sequential(Embeddings(hid_size, src_vocab),
                                PositionalEncoding(hid_size, dropout))
        tgt_emb = nn.Sequential(Embeddings(hid_size, tgt_vocab),
                                PositionalEncoding(hid_size, dropout))
        generator = Generator(hid_size, tgt_vocab)

        model = cls(encoder, decoder, src_emb, tgt_emb, generator)

        if tied_emb:
            model.tie_embeddings(tied_emb)

        model.init_params()
        return model, args
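
The args dict captured above is what makes the model re-buildable later from a saved conf or checkpoint; a minimal round-trip sketch, with MyNMT standing in as a placeholder for whichever model class owns this classmethod:

# MyNMT is a hypothetical stand-in for the concrete model class defining make_model()
model, args = MyNMT.make_model(src_vocab=8000, tgt_vocab=8000,
                               enc_layers=6, dec_layers=6, hid_size=512)
# args now holds every hyperparameter with defaults filled in, so the same
# architecture can be reconstructed without repeating the explicit arguments
model_again, _ = MyNMT.make_model(**args)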
Example #6
def main():
    # No grads required
    torch.set_grad_enabled(False)
    args = parse_args()
    gen_args = {}
    exp = Experiment(args.pop('work_dir'), read_only=True)

    assert exp.model_type.endswith('lm'), 'Only for Language models'
    decoder = Decoder.new(exp, gen_args=gen_args, model_paths=args.pop('model_path', None),
                          ensemble=args.pop('ensemble', 1))

    log_pp = log_perplexity(decoder, args['test'])
    print(f'Log perplexity: {log_pp:g}')
    print(f'Perplexity: {math.exp(log_pp):g}')
Example #7
def main():
    # No grads required for decode
    torch.set_grad_enabled(False)
    cli_args = parse_args()
    exp = Experiment(cli_args.pop('exp_dir'), read_only=True)
    dec_args = exp.config.get('decoder') or exp.config['tester'].get('decoder', {})
    validate_args(cli_args, dec_args, exp)

    decoder = Decoder.new(exp, ensemble=dec_args.pop('ensemble', 1))
    for inp, out in zip(cli_args['input'], cli_args['output']):
        log.info(f"Decode :: {inp} -> {out}")
        try:
            if cli_args.get('no_buffer'):
                return decoder.decode_stream(inp, out, **dec_args)
            else:
                return decoder.decode_file(inp, out, **dec_args)
        except Exception:
            log.exception(f"Decode failed for {inp}")
Example #8
def attach_translate_route(cli_args):
    global exp
    exp = Experiment(cli_args.pop("exp_dir"), read_only=True)
    dec_args = exp.config.get("decoder") or exp.config["tester"].get(
        "decoder", {})
    decoder = Decoder.new(exp, ensemble=dec_args.pop("ensemble", 1))
    dataprep = RtgIO(exp=exp)

    @bp.route("/translate", methods=["POST", "GET"])
    def translate():
        if request.method not in ("POST", "GET"):
            return "GET and POST are supported", 400
        if request.method == 'GET':
            sources = request.args.getlist("source", None)
        else:
            sources = (request.json or {}).get(
                'source', None) or request.form.getlist("source")
            if isinstance(sources, str):
                sources = [sources]
        if not sources:
            return "Please submit parameter 'source'", 400
        sources = [dataprep.pre_process(sent) for sent in sources]
        translations = []
        for source in sources:
            translated = decoder.decode_sentence(source, **dec_args)[0][1]
            translated = dataprep.post_process(translated.split())
            translations.append(translated)

        res = dict(source=sources, translation=translations)
        return jsonify(res)

    @bp.route("/conf.yml", methods=["GET"])
    def get_conf():
        conf_str = exp._config_file.read_text(encoding='utf-8',
                                              errors='ignore')
        return render_template('conf.yml.html', conf_str=conf_str)

    @bp.route("/about", methods=["GET"])
    def about():
        def_desc = "Model description is unavailable; please update conf.yml"
        return render_template('about.html',
                               model_desc=exp.config.get(
                                   "description", def_desc))
Example #9
def __test_model__():
    from rtg.dummy import DummyExperiment
    from rtg import Batch, my_tensor as tensor

    vocab_size = 24
    args = {
        'src_vocab': vocab_size,
        'tgt_vocab': vocab_size,
        'enc_layers': 4,
        'dec_layers': 3,
        'hid_size': 128,
        'ff_size': 256,
        'dec_rnn_type': 'GRU',
        'enc_heads': 4
    }

    from rtg.module.decoder import Decoder

    exp = DummyExperiment("work.tmp.hybridmt",
                          config={'model_type': 'hybridmt'},
                          read_only=True,
                          vocab_size=vocab_size)
    exp.model_args = args
    trainer = HybridMTTrainer(exp=exp, warmup_steps=200)
    decr = Decoder.new(exp, trainer.model)

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                  [13, 12, 11, 10, 9, 8, 7, 6, 5, 4]])
    src_lens = tensor([src.size(1)] * src.size(0))

    def check_pt_callback(**args):
        res = decr.greedy_decode(src, src_lens, max_len=12)
        for score, seq in res:
            log.info(f'{score:.4f} :: {seq}')

    batch_size = 50
    steps = 2000
    check_point = 50
    trainer.train(steps=steps,
                  check_point=check_point,
                  batch_size=batch_size,
                  check_pt_callback=check_pt_callback)
Example #10
def log_perplexity(decoder: Decoder, test_data: TextIO):
    """
    Computes log perplexity of a language model on a given test data
    :param decoder:
    :param test_data:
    :return:

    .. math::
        P(w_i | h) <-- probability of word w_i given history h
        P(w_1, w_2, ... w_N) <-- probability of observing or generating a word sequence
        P(w_1, w_2, ... w_N) = P(w_1) x P(w_2|w_1) x P(w_3 | w_1, w_2) ... <-- chain rule

        PP_M <-- Perplexity of a Model M
        PP_M(w_1, w_2, ... w_N) <-- PP_M on a sequence w_1, w_2, ... w_N
        PP_M(w_1, w_2, ... w_N) = P(w_1, w_2, ... w_N)^{-1/N}

        log(PP_M) <-- Log perplexity of a model M
        log(PP_M) = -1/N \sum_{i=1}^{N} \log P(w_i | w_1, w_2 ... w_{i-1})


    Note: log perplexity is a practical solution to deal with floating point underflow
    """
    lines = (line.strip() for line in test_data)
    test_seqs = [decoder.out_vocab.encode_as_ids(line, add_bos=True, add_eos=True)
                 for line in lines]
    count = 0
    total = 0.0
    for seq in tqdm(test_seqs, dynamic_ncols=True):
        #  batch of 1
        # TODO: make this faster using bigger batching
        batch = torch.tensor(seq, dtype=torch.long, device=device).view(1, -1)
        for step in range(1, len(seq)):
            # assumption: BOS and EOS are included
            count += 1
            history = batch[:, :step]
            word_idx = seq[step]
            log_prob = decoder.next_word_distr(history)[0, word_idx]
            total += log_prob
    log_pp = -1/count * total
    return log_pp.item()
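
A tiny sanity check of the relationship between log perplexity and perplexity, using toy probabilities rather than any model output:

import math

probs = [0.5, 0.25]                           # P(w_1 | BOS), P(w_2 | w_1)
log_pp = -sum(math.log(p) for p in probs) / len(probs)
perplexity = math.exp(log_pp)                 # geometric-mean inverse probability
assert abs(perplexity - (0.5 * 0.25) ** (-1 / len(probs))) < 1e-9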
Example #11
def __test_binmt_model__():
    from rtg.module.decoder import Decoder

    vocab_size = 20
    exp = Experiment("tmp.work",
                     config={'model_type': 'binmt'},
                     read_only=True)
    num_epoch = 100
    emb_size = 100
    model_dim = 100
    batch_size = 32

    src = tensor([[2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                  [2, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4]])
    src_lens = tensor([src.size(1)] * src.size(0))

    for reverse in (False, ):
        # train two models;
        #  first, just copy the numbers, i.e. y = x
        #  second, reverse the numbers y=(V + reserved - x)
        log.info(f"====== REVERSE={reverse}; VOCAB={vocab_size}======")
        model, args = BiNMT.make_model('DummyA',
                                       'DummyB',
                                       vocab_size,
                                       vocab_size,
                                       emb_size=emb_size,
                                       hid_size=model_dim,
                                       n_layers=2)
        trainer = BiNmtTrainer(exp=exp,
                               model=model,
                               lr=0.01,
                               warmup_steps=500,
                               step_size=2 * batch_size)

        decr = Decoder.new(exp, model, gen_args={'path': 'E1D1'})
        assert 2 == Batch.bos_val

        def print_res(res):
            for score, seq in res:
                log.info(f'{score:.4f} :: {seq}')

        for epoch in range(num_epoch):
            model.train()
            train_data = BatchIterable(vocab_size,
                                       batch_size,
                                       50,
                                       seq_len=10,
                                       reverse=reverse,
                                       batch_first=True)
            val_data = BatchIterable(vocab_size,
                                     batch_size,
                                     5,
                                     reverse=reverse,
                                     batch_first=True)
            train_loss = trainer._run_cycle(train_data,
                                            train_data,
                                            train_mode=True)
            val_loss = trainer._run_cycle(val_data, val_data, train_mode=False)
            log.info(
                f"Epoch {epoch}, training Loss: {train_loss:g} \t validation loss:{val_loss:g}"
            )
            model.eval()
            res = decr.greedy_decode(src, src_lens, max_len=17)
            print_res(res)
Example #12
    def __init__(self,
                 exp: Experiment,
                 model: Optional[NMTModel] = None,
                 model_factory: Optional[Callable] = None,
                 optim: str = 'ADAM',
                 **optim_args):
        self.last_step = -1
        self.exp = exp
        optim_state = None
        if model:
            self.model = model
        else:
            args = exp.model_args
            assert args
            assert model_factory
            self.model, args = model_factory(exp=exp, **args)
            exp.model_args = args
            last_model, self.last_step = self.exp.get_last_saved_model()
            if last_model:
                log.info(
                    f"Resuming training from step:{self.last_step}, model={last_model}"
                )
                state = torch.load(last_model, map_location=device)
                model_state = state[
                    'model_state'] if 'model_state' in state else state

                if 'optim_state' in state:
                    optim_state = state['optim_state']
                self.model.load_state_dict(model_state)
                if 'amp_state' in state and dtorch.fp16:
                    log.info("Restoring  AMP state")
                    dtorch._scaler.load_state_dict(state['amp_state'])
            else:
                log.info(
                    "No earlier check point found. Looks like this is a fresh start"
                )

        # optimizer : default args for missing fields
        for k, v in self.default_optim_args.items():
            optim_args[k] = optim_args.get(k, v)

        self.n_gpus = torch.cuda.device_count()
        self.device_ids = list(range(self.n_gpus))

        inner_opt_args = {
            k: optim_args[k]
            for k in ['lr', 'betas', 'eps', 'weight_decay', 'amsgrad']
        }

        self.core_model = self.model.to(device)

        trainable_params = self.exp.config['optim'].get('trainable', {})
        if trainable_params:
            if dtorch.is_distributed:  # model is wrapped in DP or DistributedDP
                log.warning(
                    f">> Using more than 1 GPU with 'trainable' params is NOT tested"
                )
            trainable_params = self.core_model.get_trainable_params(
                include=trainable_params.get('include'),
                exclude=trainable_params.get('exclude'))
        else:
            trainable_params = self.model.parameters()

        inner_opt = Optims[optim].new(trainable_params, **inner_opt_args)
        self.model = dtorch.maybe_distributed(self.core_model)

        if optim_state:
            log.info("restoring optimizer state from checkpoint")
            try:
                inner_opt.load_state_dict(optim_state)
            except Exception:
                log.exception("Unable to restore optimizer, skipping it.")
        self.opt = NoamOpt(self.core_model.model_dim,
                           optim_args['constant'],
                           optim_args['warmup_steps'],
                           inner_opt,
                           step=self.start_step,
                           inv_sqrt=optim_args['inv_sqrt'])

        if self.exp.read_only:
            self.tbd = NoOpSummaryWriter()
        else:
            self.tbd = SummaryWriter(log_dir=str(exp.work_dir / 'tensorboard'))

        self.exp.optim_args = optim, optim_args
        if not self.exp.read_only:
            self.exp.persist_state()
        self.samples = None
        if exp.samples_file and exp.samples_file.exists():
            with IO.reader(exp.samples_file) as f:
                self.samples = [line.strip().split('\t') for line in f]
                log.info(f"Found {len(self.samples)} sample records")
                if self.start_step == 0:
                    for samp_num, sample in enumerate(self.samples):
                        self.tbd.add_text(f"sample/{samp_num}",
                                          " || ".join(sample), 0)

            from rtg.module.decoder import Decoder
            self.decoder = Decoder.new(self.exp, self.core_model)

        if self.start_step <= 1:
            self.maybe_init_model()

        self.criterion = self.create_criterion(optim_args['criterion'])
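
NoamOpt presumably implements the warmup-then-decay learning rate schedule from "Attention Is All You Need"; a standalone sketch of that schedule (the inv_sqrt option and other RTG-specific knobs are not modeled here):

def noam_lr(step: int, model_dim: int = 512, warmup: int = 8000, constant: float = 2.0) -> float:
    # learning rate rises roughly linearly for `warmup` steps, peaks at step == warmup,
    # then decays proportionally to step ** -0.5
    step = max(step, 1)
    return constant * model_dim ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)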
Example #13
    def tune_decoder_params(self,
                            exp: Experiment,
                            tune_src: str,
                            tune_ref: str,
                            batch_size: int,
                            trials: int = 10,
                            lowercase=True,
                            beam_size=(1, 4, 8),
                            ensemble=(1, 5, 10),
                            lp_alpha=(0.0, 0.4, 0.6),
                            suggested: List[Tuple[int, int, float]] = None,
                            **fixed_args):
        _, _, _, tune_args = inspect.getargvalues(inspect.currentframe())
        tune_args.update(fixed_args)
        ex_args = ['exp', 'self', 'fixed_args', 'batch_size', 'max_len']
        if trials == 0:
            ex_args += ['beam_size', 'ensemble', 'lp_alpha']
        for x in ex_args:
            del tune_args[x]  # exclude some args

        _, step = exp.get_last_saved_model()
        tune_dir = exp.work_dir / f'tune_step{step}'
        log.info(f"Tune dir = {tune_dir}")
        tune_dir.mkdir(parents=True, exist_ok=True)
        tune_src, tune_ref = Path(tune_src), Path(tune_ref)
        assert tune_src.exists()
        assert tune_ref.exists()
        tune_src, tune_ref = list(IO.get_lines(tune_src)), list(
            IO.get_lines(tune_ref))
        assert len(tune_src) == len(tune_ref)

        tune_log = tune_dir / 'scores.json'  # resume the tuning
        memory: Dict[Tuple, float] = {}
        if tune_log.exists():
            data = json.load(tune_log.open())
            # JSON keys can't be tuples, so they were stringified
            memory = {eval(k): v for k, v in data.items()}

        beam_sizes, ensembles, lp_alphas = [], [], []
        if suggested:
            if isinstance(suggested[0], str):
                suggested = [eval(x) for x in suggested]
            suggested = [(x[0], x[1], round(x[2], 2)) for x in suggested]
            suggested_new = [x for x in suggested if x not in memory]
            beam_sizes += [x[0] for x in suggested_new]
            ensembles += [x[1] for x in suggested_new]
            lp_alphas += [x[2] for x in suggested_new]

        new_trials = trials - len(memory)
        if new_trials > 0:
            beam_sizes += [random.choice(beam_size) for _ in range(new_trials)]
            ensembles += [random.choice(ensemble) for _ in range(new_trials)]
            lp_alphas += [
                round(random.choice(lp_alpha), 2) for _ in range(new_trials)
            ]

        # ensembling is somewhat costlier, so minimize it by grouping trials that share the same ensemble size
        grouped_ens = defaultdict(list)
        for b, ens, l in zip(beam_sizes, ensembles, lp_alphas):
            grouped_ens[ens].append((b, l))
        try:
            for ens, args in grouped_ens.items():
                decoder = Decoder.new(exp, ensemble=ens)
                for b_s, lp_a in args:
                    eff_batch_size = batch_size // b_s  # effective batch size
                    name = f'tune_step{step}_beam{b_s}_ens{ens}_lp{lp_a:.2f}'
                    log.info(name)
                    out_file = tune_dir / f'{name}.out.tsv'
                    score = self.decode_eval_file(decoder,
                                                  tune_src,
                                                  out_file,
                                                  tune_ref,
                                                  batch_size=eff_batch_size,
                                                  beam_size=b_s,
                                                  lp_alpha=lp_a,
                                                  lowercase=lowercase,
                                                  **fixed_args)
                    memory[(b_s, ens, lp_a)] = score
            best_params = sorted(memory.items(),
                                 key=lambda x: x[1],
                                 reverse=True)[0][0]
            return dict(zip(['beam_size', 'ensemble', 'lp_alpha'],
                            best_params)), tune_args
        finally:
            # JSON keys can't be tuples, so we stringify them
            data = {str(k): v for k, v in memory.items()}
            IO.write_lines(tune_log, json.dumps(data))
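
The scores.json persistence above depends on stringifying tuple keys before json.dumps and eval-ing them back on resume; a self-contained sketch of just that round trip:

import json

memory = {(4, 1, 0.6): 35.2}   # (beam_size, ensemble, lp_alpha) -> score
text = json.dumps({str(k): v for k, v in memory.items()})
# eval() is acceptable here only because the file is produced by this same code
restored = {eval(k): v for k, v in json.loads(text).items()}
assert restored == memory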
Example #14
    def __init__(self,
                 exp: Experiment,
                 model: Optional[NMTModel] = None,
                 model_factory: Optional[Callable] = None,
                 optim: str = 'ADAM',
                 **optim_args):
        self.start_step = 0
        self.last_step = -1
        self.exp = exp
        optim_state = None
        if model:
            self.model = model
        else:
            args = exp.model_args
            assert args
            assert model_factory
            self.model, args = model_factory(exp=exp, **args)
            exp.model_args = args
            last_model, self.last_step = self.exp.get_last_saved_model()
            if last_model:
                self.start_step = self.last_step + 1
                log.info(
                    f"Resuming training from step:{self.start_step}, model={last_model}"
                )
                state = torch.load(last_model)
                model_state = state[
                    'model_state'] if 'model_state' in state else state
                if 'optim_state' in state:
                    optim_state = state['optim_state']
                self.model.load_state_dict(model_state)
            else:
                log.info(
                    "No earlier check point found. Looks like this is a fresh start"
                )

        # making optimizer
        optim_args['lr'] = optim_args.get('lr', 0.1)
        optim_args['betas'] = optim_args.get('betas', [0.9, 0.98])
        optim_args['eps'] = optim_args.get('eps', 1e-9)

        warmup_steps = optim_args.pop('warmup_steps', 8000)
        self._smoothing = optim_args.pop('label_smoothing', 0.1)
        constant = optim_args.pop('constant', 2)

        self.model = self.model.to(device)

        inner_opt = Optims[optim].new(self.model.parameters(), **optim_args)
        if optim_state:
            log.info("restoring optimizer state from checkpoint")
            try:
                inner_opt.load_state_dict(optim_state)
            except Exception:
                log.exception("Unable to restore optimizer, skipping it.")
        self.opt = NoamOpt(self.model.model_dim,
                           constant,
                           warmup_steps,
                           inner_opt,
                           step=self.start_step)

        optim_args.update(
            dict(warmup_steps=warmup_steps,
                 label_smoothing=self._smoothing,
                 constant=constant))
        if self.exp.read_only:
            self.tbd = NoOpSummaryWriter()
        else:
            self.tbd = SummaryWriter(log_dir=str(exp.work_dir / 'tensorboard'))

        self.exp.optim_args = optim, optim_args
        if not self.exp.read_only:
            self.exp.persist_state()
        self.samples = None
        if exp.samples_file.exists():
            with IO.reader(exp.samples_file) as f:
                self.samples = [line.strip().split('\t') for line in f]
                log.info(f"Found {len(self.samples)} sample records")
                if self.start_step == 0:
                    for samp_num, sample in enumerate(self.samples):
                        self.tbd.add_text(f"sample/{samp_num}",
                                          " || ".join(sample), 0)

            from rtg.module.decoder import Decoder
            self.decoder = Decoder.new(self.exp, self.model)

        if self.start_step == 0:
            self.init_embeddings()
        self.model = self.model.to(device)
Example #15
    def run_tests(self, exp=None, args=None):
        exp = exp or self.exp
        args = args or exp.config['tester']
        suit: Dict[str, List] = args['suit']
        assert suit
        log.info(f"Found {len(suit)} suit :: {suit.keys()}")

        _, step = exp.get_last_saved_model()
        if 'decoder' not in args:
            args['decoder'] = {}
        dec_args: Dict = args['decoder']
        best_params = copy.deepcopy(dec_args)
        max_len = best_params.get('max_len', 50)
        batch_size = best_params.get('batch_size', 20_000)
        # TODO: this has grown messy (trying to stay backward compatible); improve the logic here
        if 'tune' in dec_args and not dec_args['tune'].get('tuned'):
            tune_args: Dict = dec_args['tune']
            prep_args = exp.config['prep']
            if 'tune_src' not in tune_args:
                tune_args['tune_src'] = prep_args['valid_src']
            if 'tune_ref' not in tune_args:
                tune_args['tune_ref'] = prep_args.get('valid_ref',
                                                      prep_args['valid_tgt'])
            best_params, tuner_args_ext = self.tune_decoder_params(
                exp=exp, max_len=max_len, batch_size=batch_size, **tune_args)
            log.info(f"tuner args = {tuner_args_ext}")
            log.info(f"Tuning complete: best_params: {best_params}")
            dec_args['tune'].update(
                tuner_args_ext)  # Update the config file with default args
            dec_args['tune']['tuned'] = True

        if 'tune' in best_params:
            del best_params['tune']

        log.info(f"params: {best_params}")
        beam_size = best_params.get('beam_size', 4)
        ensemble: int = best_params.pop('ensemble', 5)
        lp_alpha = best_params.get('lp_alpha', 0.0)
        eff_batch_size = batch_size // beam_size

        dec_args.update(
            dict(beam_size=beam_size,
                 lp_alpha=lp_alpha,
                 ensemble=ensemble,
                 max_len=max_len,
                 batch_size=batch_size))
        exp.persist_state()  # update the config

        assert step > 0, 'looks like no model is saved or invalid experiment dir'
        test_dir = exp.work_dir / f'test_step{step}_beam{beam_size}_ens{ensemble}_lp{lp_alpha}'
        log.info(f"Test Dir = {test_dir}")
        test_dir.mkdir(parents=True, exist_ok=True)

        decoder = Decoder.new(exp, ensemble=ensemble)
        for name, data in suit.items():
            # noinspection PyBroadException
            src, ref = data, None
            out_file = None
            if isinstance(data, list):
                src, ref = data[:2]
            elif isinstance(data, dict):
                src, ref = data['src'], data.get('ref')
                out_file = data.get('out')
            try:
                orig_src = Path(src).resolve()
                src_link = test_dir / f'{name}.src'
                ref_link = test_dir / f'{name}.ref'
                buffer = [(src_link, orig_src)]
                if ref:
                    orig_ref = Path(ref).resolve()
                    buffer.append((ref_link, orig_ref))
                for link, orig in buffer:
                    if not link.exists():
                        link.symlink_to(orig)
                out_file = test_dir / f'{name}.out.tsv' if not out_file else out_file
                out_file.parent.mkdir(parents=True, exist_ok=True)

                self.decode_eval_file(decoder,
                                      src_link,
                                      out_file,
                                      ref_link,
                                      batch_size=eff_batch_size,
                                      beam_size=beam_size,
                                      lp_alpha=lp_alpha,
                                      max_len=max_len)
            except Exception as e:
                log.exception(f"Something went wrong with '{name}' test")
                err = test_dir / f'{name}.err'
                err.write_text(str(e))
Example #16
    def export(self,
               target: Path,
               name: str = None,
               ensemble: int = 1,
               copy_config=True,
               copy_vocab=True):
        to_exp = Experiment(target.resolve(), config=self.exp.config)

        if copy_config:
            log.info("Copying config")
            to_exp.persist_state()

        if copy_vocab:
            log.info("Copying vocabulary")
            self.exp.copy_vocabs(to_exp)
        assert ensemble > 0
        assert name
        assert len(name.split()) == 1
        log.info("Going to average models and then copy")
        model_paths = self.exp.list_models()[:ensemble]
        log.info(f'Model paths: {model_paths}')
        chkpt_state = torch.load(model_paths[0], map_location=device)
        if ensemble > 1:
            log.info("Averaging them ...")
            avg_state = Decoder.average_states(model_paths)
            chkpt_state = dict(model_state=avg_state,
                               model_type=chkpt_state['model_type'],
                               model_args=chkpt_state['model_args'])
        log.info("Instantiating it ...")
        model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
        log.info(f"Exporting to {target}")
        to_exp = Experiment(target, config=self.exp.config)
        to_exp.persist_state()

        IO.copy_file(self.exp.model_dir / 'scores.tsv',
                     to_exp.model_dir / 'scores.tsv')
        if (self.exp.work_dir / 'rtg.zip').exists():
            IO.copy_file(self.exp.work_dir / 'rtg.zip',
                         to_exp.work_dir / 'rtg.zip')

        src_chkpt = chkpt_state
        log.warning(
            "step number, training loss and validation loss are not recalculated."
        )
        step_num, train_loss, val_loss = [
            src_chkpt.get(n, -1) for n in ['step', 'train_loss', 'val_loss']
        ]
        copy_fields = [
            'optim_state', 'step', 'train_loss', 'valid_loss', 'time',
            'rtg_version', 'model_type', 'model_args'
        ]
        state = dict((c, src_chkpt[c]) for c in copy_fields if c in src_chkpt)
        state['model_state'] = model.state_dict()
        state['averaged_time'] = time.time()
        state['model_paths'] = model_paths
        state['num_checkpts'] = len(model_paths)
        prefix = f'model_{name}_avg{len(model_paths)}'
        to_exp.store_model(step_num,
                           state,
                           train_score=train_loss,
                           val_score=val_loss,
                           keep=10,
                           prefix=prefix)
        chkpts = [mp.name for mp in model_paths]
        status = {
            'parent': str(self.exp.work_dir),
            'ensemble': ensemble,
            'checkpts': chkpts,
            'when': datetime.datetime.now().isoformat(),
            'who': os.environ.get('USER', '<unknown>'),
        }
        yaml.dump(status, stream=to_exp.work_dir / '_EXPORTED')

        if self.exp._trained_flag.exists():
            IO.copy_file(self.exp._trained_flag, to_exp._trained_flag)
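
Decoder.average_states presumably produces an elementwise average of parameter tensors across checkpoints (this example passes paths; Example #18 passes already-loaded state dicts); a standalone sketch over loaded state dicts, not RTG's actual implementation:

import torch

def average_state_dicts(states):
    # elementwise mean of matching parameter tensors across checkpoints
    return {key: torch.mean(torch.stack([s[key].float() for s in states]), dim=0)
            for key in states[0]}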
Example #17
    def inherit_parent(self):
        parent = self.config['parent']
        parent_exp = TranslationExperiment(parent['experiment'],
                                           read_only=True)
        log.info(f"Parent experiment: {parent_exp.work_dir}")
        parent_exp.has_prepared()
        vocab_spec = parent.get('vocab')
        if vocab_spec:
            log.info(f"Parent vocabs inheritance spec: {vocab_spec}")
            codec_lib = parent_exp.config['prep'].get('codec_lib')
            if codec_lib:
                self.config['prep']['codec_lib'] = codec_lib

            def _locate_field_file(exp: TranslationExperiment,
                                   name,
                                   check_exists=False) -> Path:
                switch = {
                    'src': exp._src_field_file,
                    'tgt': exp._tgt_field_file,
                    'shared': exp._shared_field_file
                }
                assert name in switch, f'{name} not allowed; valid options= {switch.keys()}'
                file = switch[name]
                if check_exists:
                    assert file.exists(), f'{file} does not exist; for {name} of {exp.work_dir}'
                return file

            for to_field, from_field in vocab_spec.items():
                from_field_file = _locate_field_file(parent_exp,
                                                     from_field,
                                                     check_exists=True)
                to_field_file = _locate_field_file(self,
                                                   to_field,
                                                   check_exists=False)
                IO.copy_file(from_field_file, to_field_file)
            self.reload_vocabs()
        else:
            log.info("No vocabularies are inherited from parent")
        model_spec = parent.get('model')
        if model_spec:
            log.info("Parent model inheritance spec")
            if model_spec.get('args'):
                self.model_args = parent_exp.model_args
            ensemble = model_sepc.get('ensemble', 1)
            model_paths = parent_exp.list_models(sort='step',
                                                 desc=True)[:ensemble]
            log.info(
                f"Averaging {len(model_paths)} checkpoints of parent model: \n{model_paths}"
            )
            from rtg.module.decoder import Decoder
            avg_state = Decoder.average_states(model_paths=model_paths)
            log.info(
                f"Saving parent model's state to {self.parent_model_state}")
            torch.save(avg_state, self.parent_model_state)

        shrink_spec = parent.get('shrink')
        if shrink_spec:
            remap_src, remap_tgt = self.shrink_vocabs()

            def map_rows(mapping: List[int], source: torch.Tensor, name=''):
                assert max(mapping) < len(source)
                target = torch.zeros((len(mapping), *source.shape[1:]),
                                     dtype=source.dtype,
                                     device=source.device)
                for new_idx, old_idx in enumerate(mapping):
                    target[new_idx] = source[old_idx]
                log.info(f"Mapped {name} {source.shape} --> {target.shape} ")
                return target

            """ src_embed.0.lut.weight [N x d]
                tgt_embed.0.lut.weight [N x d]
                generator.proj.weight [N x d]
                generator.proj.bias [N] """
            if remap_src:
                key = 'src_embed.0.lut.weight'
                avg_state[key] = map_rows(remap_src, avg_state[key], name=key)
            if remap_tgt:
                map_keys = [
                    'tgt_embed.0.lut.weight', 'generator.proj.weight',
                    'generator.proj.bias'
                ]
                for key in map_keys:
                    if key not in avg_state:
                        log.warning(
                            f'{key} not found in avg_state of parent model. Mapping skipped'
                        )
                        continue
                    avg_state[key] = map_rows(remap_tgt,
                                              avg_state[key],
                                              name=key)
            if self.parent_model_state.exists():
                self.parent_model_state.rename(
                    self.parent_model_state.with_suffix('.orig'))
            torch.save(avg_state, self.parent_model_state)
            self.persist_state()  # this will fix src_vocab and tgt_vocab of model_args conf
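
For reference, map_rows above can be written as a single gather; a minimal equivalent sketch using torch.index_select, keeping the same row-remapping semantics:

import torch

def map_rows_fast(mapping, source: torch.Tensor, name=''):
    index = torch.tensor(mapping, dtype=torch.long, device=source.device)
    return source.index_select(0, index)   # row new_idx <- source[mapping[new_idx]]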
Example #18
    def export(self,
               target: Path,
               name: str = None,
               ensemble: int = 1,
               copy_config=True,
               copy_vocab=True):
        to_exp = Experiment(target, config=self.exp.config)

        if copy_config:
            log.info("Copying config")
            to_exp.persist_state()

        if copy_vocab:
            log.info("Copying vocabulary")
            self.exp.copy_vocabs(to_exp)

        if ensemble > 0:
            assert name
            assert len(name.split()) == 1
            log.info("Going to average models and then copy")
            model_paths = self.exp.list_models()[:ensemble]
            log.info(f'Model paths: {model_paths}')
            checkpts = [torch.load(mp) for mp in model_paths]
            states = [chkpt['model_state'] for chkpt in checkpts]

            log.info("Averaging them ...")
            avg_state = Decoder.average_states(*states)
            chkpt_state = dict(model_state=avg_state,
                               model_type=checkpts[0]['model_type'],
                               model_args=checkpts[0]['model_args'])
            log.info("Instantiating it ...")
            model = instantiate_model(checkpt_state=chkpt_state, exp=self.exp)
            log.info(f"Exporting to {target}")
            to_exp = Experiment(target, config=self.exp.config)
            to_exp.persist_state()

            src_chkpt = checkpts[0]
            log.warning(
                "step number, training loss and validation loss are not recalculated."
            )
            step_num, train_loss, val_loss = [
                src_chkpt.get(n, -1)
                for n in ['step', 'train_loss', 'val_loss']
            ]
            copy_fields = [
                'optim_state', 'step', 'train_loss', 'valid_loss', 'time',
                'rtg_version', 'model_type', 'model_args'
            ]
            state = dict(
                (c, src_chkpt[c]) for c in copy_fields if c in src_chkpt)
            state['model_state'] = model.state_dict()
            state['averaged_time'] = time.time()
            state['model_paths'] = model_paths
            state['num_checkpts'] = len(model_paths)
            prefix = f'model_{name}_avg{len(model_paths)}'
            to_exp.store_model(step_num,
                               state,
                               train_score=train_loss,
                               val_score=val_loss,
                               keep=10,
                               prefix=prefix)