Example #1
def main():
    # No grads required
    torch.set_grad_enabled(False)
    args = parse_args()
    gen_args = {}
    exp = Experiment(args.pop('work_dir'), read_only=True)
    validate_args(args, exp)

    if exp.model_type == 'binmt':
        if not args.get('binmt_path'):
            raise Exception('--binmt-path argument is needed for BiNMT model.')
        gen_args['path'] = args.pop('binmt_path')

    weights = args.get('weights')
    if weights:
        decoder = Decoder.combo_new(exp,
                                    model_paths=args.pop('model_path'),
                                    weights=weights)
    else:
        decoder = Decoder.new(exp,
                              gen_args=gen_args,
                              model_paths=args.pop('model_path', None),
                              ensemble=args.pop('ensemble', 1))
    if args.pop('interactive'):
        if weights:
            log.warning(
                "Interactive shell cannot reload models in combo mode (FIXME/TODO)."
            )
        if args['input'] != sys.stdin or args['output'] != sys.stdout:
            log.warning(
                '--input and --output args are not applicable in --interactive mode'
            )
        args.pop('input')
        args.pop('output')

        while True:
            try:
                # a hacky way to unload and reload the model when the user switches models
                decoder.decode_interactive(**args)
                break  # exit loop if there is no request for reload
            except ReloadEvent as re:
                decoder = Decoder.new(exp,
                                      gen_args=gen_args,
                                      model_paths=re.model_paths)
                args = re.state
                # go back to loop and redo interactive shell
    else:
        return decoder.decode_file(args.pop('input'), args.pop('output'),
                                   **args)
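
The example assumes a parse_args() helper that returns the CLI options as a plain dict. A minimal, hypothetical reconstruction based only on the option names used above (the real rtg flags, defaults, and short names may differ):

import argparse
import sys


def parse_args() -> dict:
    # Hypothetical sketch covering only the options referenced in the example above
    p = argparse.ArgumentParser(description="Decode with a trained rtg experiment")
    p.add_argument('work_dir', help='experiment directory')
    p.add_argument('--input', default=sys.stdin, help='input file (default: stdin)')
    p.add_argument('--output', default=sys.stdout, help='output file (default: stdout)')
    p.add_argument('--interactive', action='store_true', help='open an interactive decoding shell')
    p.add_argument('--ensemble', type=int, default=1, help='number of checkpoints to average')
    p.add_argument('--model-path', dest='model_path', nargs='*', help='explicit model checkpoint path(s)')
    p.add_argument('--weights', type=float, nargs='*', help='weights for combining multiple model paths')
    p.add_argument('--binmt-path', dest='binmt_path', help='sub-model path for BiNMT models, e.g. E1D1')
    return vars(p.parse_args())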
Example #2
def __test_seq2seq_model__():
    """
        batch_size = 4
        p = '/Users/tg/work/me/rtg/saral/runs/1S-rnn-basic'
        exp = Experiment(p)
        steps = 3000
        check_pt = 100
        trainer = SteppedRNNMTTrainer(exp=exp, lr=0.01, warmup_steps=100)
        trainer.train(steps=steps, check_point=check_pt, batch_size=batch_size)
    """
    from rtg.dummy import DummyExperiment
    from rtg.module.decoder import Decoder

    vocab_size = 50
    batch_size = 30
    exp = DummyExperiment("tmp.work",
                          config={'model_type': 'seq'
                                  '2seq'},
                          read_only=True,
                          vocab_size=vocab_size)
    emb_size = 100
    model_dim = 100
    steps = 3000
    check_pt = 100

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                  [13, 12, 11, 10, 9, 8, 7, 6, 5, 4]])
    src_lens = tensor([src.size(1)] * src.size(0))

    for reverse in (False, ):
        # train two models;
        #  first, just copy the numbers, i.e. y = x
        #  second, reverse the numbers y=(V + reserved - x)
        log.info(f"====== REVERSE={reverse}; VOCAB={vocab_size}======")
        model, args = RNNMT.make_model('DummyA',
                                       'DummyB',
                                       vocab_size,
                                       vocab_size,
                                       attention='dot',
                                       emb_size=emb_size,
                                       hid_size=model_dim,
                                       n_layers=1)
        trainer = SteppedRNNMTTrainer(exp=exp,
                                      model=model,
                                      lr=0.01,
                                      warmup_steps=100)
        decr = Decoder.new(exp, model)

        def check_pt_callback(**args):
            res = decr.greedy_decode(src, src_lens, max_len=17)
            for score, seq in res:
                log.info(f'{score:.4f} :: {seq}')

        trainer.train(steps=steps,
                      check_point=check_pt,
                      batch_size=batch_size,
                      check_pt_callback=check_pt_callback)
Example #3
def __test_model__():
    from rtg.dummy import DummyExperiment
    vocab_size = 30
    args = {
        'src_vocab': vocab_size,
        'tgt_vocab': vocab_size,
        'enc_layers': 0,
        'dec_layers': 4,
        'hid_size': 64,
        'ff_size': 64,
        'n_heads': 4,
        'activation': 'gelu'
    }
    # debug aid (disabled): once the trainer below builds the model, its
    # parameter shapes can be inspected via trainer.model.named_parameters()

    from rtg.module.decoder import Decoder

    config = {
        'model_type': 'tfmnmt',
        'trainer': {
            'init_args': {
                'chunk_size': 2
            }
        }
    }
    exp = DummyExperiment("work.tmp.t2t",
                          config=config,
                          read_only=True,
                          vocab_size=vocab_size)
    exp.model_args = args
    trainer = TransformerTrainer(exp=exp, warmup_steps=200)
    decr = Decoder.new(exp, trainer.model)

    assert 2 == Batch.bos_val
    src = tensor(
        [[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, Batch.eos_val, Batch.pad_value],
         [
             13, 12, 11, 10, 9, 8, 7, 6, Batch.eos_val, Batch.pad_value,
             Batch.pad_value, Batch.pad_value
         ]])
    src_lens = tensor([src.size(1)] * src.size(0))

    def check_pt_callback(**args):
        res = decr.greedy_decode(src, src_lens, max_len=12)
        for score, seq in res:
            log.info(f'{score:.4f} :: {seq}')

    batch_size = 50
    steps = 1000
    check_point = 50
    trainer.train(steps=steps,
                  check_point=check_point,
                  batch_size=batch_size,
                  check_pt_callback=check_pt_callback)
Example #4
def __test_model__():
    from rtg.data.dummy import DummyExperiment
    from rtg import Batch, my_tensor as tensor

    vocab_size = 24
    args = {
        'src_vocab': vocab_size,
        'tgt_vocab': vocab_size,
        'enc_layers': 0,
        'dec_layers': 4,
        'hid_size': 32,
        'eff_dims': [],
        'dff_dims': [64, 128, 128, 64],
        'enc_depth_probs': [],
        'dec_depth_probs': [1.0, 0.75, 0.5, 0.75],
        'n_heads': 4,
        'activation': 'relu'
    }

    from rtg.module.decoder import Decoder

    config = {
        'model_type': 'wvskptfmnmt',
        'trainer': {'init_args': {'chunk_size': 2, 'grad_accum': 5}},
        'optim': {
            'args': {
                # "cross_entropy", "smooth_kld", "binary_cross_entropy", "triplet_loss"
                'criterion': "smooth_kld",
                'lr': 0.01,
                'inv_sqrt': True
            }
        }
    }

    exp = DummyExperiment("work.tmp.wvskptfmnmt", config=config, read_only=True,
                          vocab_size=vocab_size)
    exp.model_args = args
    trainer = WVSKPTransformerTrainer(exp=exp, warmup_steps=200, **config['optim']['args'])
    decr = Decoder.new(exp, trainer.model)

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, Batch.eos_val, Batch.pad_value],
                  [13, 12, 11, 10, 9, 8, 7, 6, Batch.eos_val, Batch.pad_value, Batch.pad_value,
                   Batch.pad_value]])
    src_lens = tensor([src.size(1)] * src.size(0))

    def check_pt_callback(**args):
        res = decr.greedy_decode(src, src_lens, max_len=12)
        for score, seq in res:
            log.info(f'{score:.4f} :: {seq}')

    batch_size = 50
    steps = 200
    check_point = 10
    trainer.train(steps=steps, check_point=check_point, batch_size=batch_size,
                  check_pt_callback=check_pt_callback)
Example #5
def main():
    # No grads required
    torch.set_grad_enabled(False)
    args = parse_args()
    gen_args = {}
    exp = Experiment(args.pop('work_dir'), read_only=True)

    assert exp.model_type.endswith('lm'), 'Only for Language models'
    decoder = Decoder.new(exp, gen_args=gen_args, model_paths=args.pop('model_path', None),
                          ensemble=args.pop('ensemble', 1))

    log_pp = log_perplexity(decoder, args['test'])
    print(f'Log perplexity: {log_pp:g}')
    print(f'Perplexity: {math.exp(log_pp):g}')
Example #6
def main():
    # No grads required for decode
    torch.set_grad_enabled(False)
    cli_args = parse_args()
    exp = Experiment(cli_args.pop('exp_dir'), read_only=True)
    dec_args = exp.config.get('decoder') or exp.config['tester'].get('decoder', {})
    validate_args(cli_args, dec_args, exp)

    decoder = Decoder.new(exp, ensemble=dec_args.pop('ensemble', 1))
    for inp, out in zip(cli_args['input'], cli_args['output']):
        log.info(f"Decode :: {inp} -> {out}")
        try:
            if cli_args.get('no_buffer'):
                decoder.decode_stream(inp, out, **dec_args)
            else:
                decoder.decode_file(inp, out, **dec_args)
        except Exception:
            log.exception(f"Decode failed for {inp}")
Example #7
def __test_model__():
    from rtg.dummy import DummyExperiment
    from rtg import Batch, my_tensor as tensor

    vocab_size = 24
    args = {
        'src_vocab': vocab_size,
        'tgt_vocab': vocab_size,
        'enc_layers': 4,
        'dec_layers': 3,
        'hid_size': 128,
        'ff_size': 256,
        'dec_rnn_type': 'GRU',
        'enc_heads': 4
    }

    from rtg.module.decoder import Decoder

    exp = DummyExperiment("work.tmp.hybridmt",
                          config={'model_type': 'hybridmt'},
                          read_only=True,
                          vocab_size=vocab_size)
    exp.model_args = args
    trainer = HybridMTTrainer(exp=exp, warmup_steps=200)
    decr = Decoder.new(exp, trainer.model)

    assert 2 == Batch.bos_val
    src = tensor([[4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                  [13, 12, 11, 10, 9, 8, 7, 6, 5, 4]])
    src_lens = tensor([src.size(1)] * src.size(0))

    def check_pt_callback(**args):
        res = decr.greedy_decode(src, src_lens, max_len=12)
        for score, seq in res:
            log.info(f'{score:.4f} :: {seq}')

    batch_size = 50
    steps = 2000
    check_point = 50
    trainer.train(steps=steps,
                  check_point=check_point,
                  batch_size=batch_size,
                  check_pt_callback=check_pt_callback)
Example #8
def attach_translate_route(cli_args):
    global exp
    exp = Experiment(cli_args.pop("exp_dir"), read_only=True)
    dec_args = exp.config.get("decoder") or exp.config["tester"].get(
        "decoder", {})
    decoder = Decoder.new(exp, ensemble=dec_args.pop("ensemble", 1))
    dataprep = RtgIO(exp=exp)

    @bp.route("/translate", methods=["POST", "GET"])
    def translate():
        if request.method not in ("POST", "GET"):
            return "GET and POST are supported", 400
        if request.method == 'GET':
            sources = request.args.getlist("source", None)
        else:
            sources = (request.json or {}).get(
                'source', None) or request.form.getlist("source")
            if isinstance(sources, str):
                sources = [sources]
        if not sources:
            return "Please submit parameter 'source'", 400
        sources = [dataprep.pre_process(sent) for sent in sources]
        translations = []
        for source in sources:
            translated = decoder.decode_sentence(source, **dec_args)[0][1]
            translated = dataprep.post_process(translated.split())
            translations.append(translated)

        res = dict(source=sources, translation=translations)
        return jsonify(res)

    @bp.route("/conf.yml", methods=["GET"])
    def get_conf():
        conf_str = exp._config_file.read_text(encoding='utf-8',
                                              errors='ignore')
        return render_template('conf.yml.html', conf_str=conf_str)

    @bp.route("/about", methods=["GET"])
    def about():
        def_desc = "Model description is unavailable; please update conf.yml"
        return render_template('about.html',
                               model_desc=exp.config.get(
                                   "description", def_desc))
Example #9
def __test_binmt_model__():
    from rtg.module.decoder import Decoder

    vocab_size = 20
    exp = Experiment("tmp.work",
                     config={'model_type': 'binmt'},
                     read_only=True)
    num_epoch = 100
    emb_size = 100
    model_dim = 100
    batch_size = 32

    src = tensor([[2, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
                  [2, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4]])
    src_lens = tensor([src.size(1)] * src.size(0))

    for reverse in (False, ):
        # train two models;
        #  first, just copy the numbers, i.e. y = x
        #  second, reverse the numbers y=(V + reserved - x)
        log.info(f"====== REVERSE={reverse}; VOCAB={vocab_size}======")
        model, args = BiNMT.make_model('DummyA',
                                       'DummyB',
                                       vocab_size,
                                       vocab_size,
                                       emb_size=emb_size,
                                       hid_size=model_dim,
                                       n_layers=2)
        trainer = BiNmtTrainer(exp=exp,
                               model=model,
                               lr=0.01,
                               warmup_steps=500,
                               step_size=2 * batch_size)

        decr = Decoder.new(exp, model, gen_args={'path': 'E1D1'})
        assert 2 == Batch.bos_val

        def print_res(res):
            for score, seq in res:
                log.info(f'{score:.4f} :: {seq}')

        for epoch in range(num_epoch):
            model.train()
            train_data = BatchIterable(vocab_size,
                                       batch_size,
                                       50,
                                       seq_len=10,
                                       reverse=reverse,
                                       batch_first=True)
            val_data = BatchIterable(vocab_size,
                                     batch_size,
                                     5,
                                     reverse=reverse,
                                     batch_first=True)
            train_loss = trainer._run_cycle(train_data,
                                            train_data,
                                            train_mode=True)
            val_loss = trainer._run_cycle(val_data, val_data, train_mode=False)
            log.info(
                f"Epoch {epoch}, training Loss: {train_loss:g} \t validation loss:{val_loss:g}"
            )
            model.eval()
            res = decr.greedy_decode(src, src_lens, max_len=17)
            print_res(res)
Example #10
    def __init__(self,
                 exp: Experiment,
                 model: Optional[NMTModel] = None,
                 model_factory: Optional[Callable] = None,
                 optim: str = 'ADAM',
                 **optim_args):
        self.start_step = 0
        self.last_step = -1
        self.exp = exp
        optim_state = None
        if model:
            self.model = model
        else:
            args = exp.model_args
            assert args
            assert model_factory
            self.model, args = model_factory(exp=exp, **args)
            exp.model_args = args
            last_model, self.last_step = self.exp.get_last_saved_model()
            if last_model:
                self.start_step = self.last_step + 1
                log.info(
                    f"Resuming training from step:{self.start_step}, model={last_model}"
                )
                state = torch.load(last_model)
                model_state = state[
                    'model_state'] if 'model_state' in state else state
                if 'optim_state' in state:
                    optim_state = state['optim_state']
                self.model.load_state_dict(model_state)
            else:
                log.info(
                    "No earlier check point found. Looks like this is a fresh start"
                )

        # making optimizer
        optim_args['lr'] = optim_args.get('lr', 0.1)
        optim_args['betas'] = optim_args.get('betas', [0.9, 0.98])
        optim_args['eps'] = optim_args.get('eps', 1e-9)

        warmup_steps = optim_args.pop('warmup_steps', 8000)
        self._smoothing = optim_args.pop('label_smoothing', 0.1)
        constant = optim_args.pop('constant', 2)

        self.model = self.model.to(device)

        inner_opt = Optims[optim].new(self.model.parameters(), **optim_args)
        if optim_state:
            log.info("restoring optimizer state from checkpoint")
            try:
                inner_opt.load_state_dict(optim_state)
            except Exception:
                log.exception("Unable to restore optimizer, skipping it.")
        self.opt = NoamOpt(self.model.model_dim,
                           constant,
                           warmup_steps,
                           inner_opt,
                           step=self.start_step)

        optim_args.update(
            dict(warmup_steps=warmup_steps,
                 label_smoothing=self._smoothing,
                 constant=constant))
        if self.exp.read_only:
            self.tbd = NoOpSummaryWriter()
        else:
            self.tbd = SummaryWriter(log_dir=str(exp.work_dir / 'tensorboard'))

        self.exp.optim_args = optim, optim_args
        if not self.exp.read_only:
            self.exp.persist_state()
        self.samples = None
        if exp.samples_file.exists():
            with IO.reader(exp.samples_file) as f:
                self.samples = [line.strip().split('\t') for line in f]
                log.info(f"Found {len(self.samples)} sample records")
                if self.start_step == 0:
                    for samp_num, sample in enumerate(self.samples):
                        self.tbd.add_text(f"sample/{samp_num}",
                                          " || ".join(sample), 0)

            from rtg.module.decoder import Decoder
            self.decoder = Decoder.new(self.exp, self.model)

        if self.start_step == 0:
            self.init_embeddings()
        self.model = self.model.to(device)
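
NoamOpt above wraps the inner optimizer in a warmup-then-decay learning-rate schedule driven by model_dim, constant, and warmup_steps. Assuming it follows the standard "Noam" schedule from the Transformer paper (an assumption; rtg's exact formula may differ), the rate per step looks like this:

def noam_rate(step: int, model_dim: int, constant: float = 2, warmup_steps: int = 8000) -> float:
    # Linear warmup for `warmup_steps` steps, then inverse-square-root decay
    step = max(step, 1)
    return constant * model_dim ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)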
Example #11
    def run_tests(self, exp=None, args=None):
        exp = exp or self.exp
        args = args or exp.config['tester']
        suit: Dict[str, List] = args['suit']
        assert suit
        log.info(f"Found {len(suit)} suit :: {suit.keys()}")

        _, step = exp.get_last_saved_model()
        if 'decoder' not in args:
            args['decoder'] = {}
        dec_args: Dict = args['decoder']
        best_params = copy.deepcopy(dec_args)
        max_len = best_params.get('max_len', 50)
        batch_size = best_params.get('batch_size', 20_000)
        # TODO: this has grown messy (trying to stay backward compatible); improve the logic here
        if 'tune' in dec_args and not dec_args['tune'].get('tuned'):
            tune_args: Dict = dec_args['tune']
            prep_args = exp.config['prep']
            if 'tune_src' not in tune_args:
                tune_args['tune_src'] = prep_args['valid_src']
            if 'tune_ref' not in tune_args:
                tune_args['tune_ref'] = prep_args.get('valid_ref',
                                                      prep_args['valid_tgt'])
            best_params, tuner_args_ext = self.tune_decoder_params(
                exp=exp, max_len=max_len, batch_size=batch_size, **tune_args)
            log.info(f"tuner args = {tuner_args_ext}")
            log.info(f"Tuning complete: best_params: {best_params}")
            dec_args['tune'].update(
                tuner_args_ext)  # Update the config file with default args
            dec_args['tune']['tuned'] = True

        if 'tune' in best_params:
            del best_params['tune']

        log.info(f"params: {best_params}")
        beam_size = best_params.get('beam_size', 4)
        ensemble: int = best_params.pop('ensemble', 5)
        lp_alpha = best_params.get('lp_alpha', 0.0)
        eff_batch_size = batch_size // beam_size

        dec_args.update(
            dict(beam_size=beam_size,
                 lp_alpha=lp_alpha,
                 ensemble=ensemble,
                 max_len=max_len,
                 batch_size=batch_size))
        exp.persist_state()  # update the config

        assert step > 0, 'looks like no model is saved or invalid experiment dir'
        test_dir = exp.work_dir / f'test_step{step}_beam{beam_size}_ens{ensemble}_lp{lp_alpha}'
        log.info(f"Test Dir = {test_dir}")
        test_dir.mkdir(parents=True, exist_ok=True)

        decoder = Decoder.new(exp, ensemble=ensemble)
        for name, data in suit.items():
            # noinspection PyBroadException
            src, ref = data, None
            out_file = None
            if isinstance(data, list):
                src, ref = data[:2]
            elif isinstance(data, dict):
                src, ref = data['src'], data.get('ref')
                out_file = data.get('out')
            try:
                orig_src = Path(src).resolve()
                src_link = test_dir / f'{name}.src'
                ref_link = test_dir / f'{name}.ref'
                buffer = [(src_link, orig_src)]
                if ref:
                    orig_ref = Path(ref).resolve()
                    buffer.append((ref_link, orig_ref))
                for link, orig in buffer:
                    if not link.exists():
                        link.symlink_to(orig)
                out_file = test_dir / f'{name}.out.tsv' if not out_file else out_file
                out_file.parent.mkdir(parents=True, exist_ok=True)

                self.decode_eval_file(decoder,
                                      src_link,
                                      out_file,
                                      ref_link,
                                      batch_size=eff_batch_size,
                                      beam_size=beam_size,
                                      lp_alpha=lp_alpha,
                                      max_len=max_len)
            except Exception as e:
                log.exception(f"Something went wrong with '{name}' test")
                err = test_dir / f'{name}.err'
                err.write_text(str(e))
Example #12
    def tune_decoder_params(self,
                            exp: Experiment,
                            tune_src: str,
                            tune_ref: str,
                            batch_size: int,
                            trials: int = 10,
                            lowercase=True,
                            beam_size=(1, 4, 8),
                            ensemble=(1, 5, 10),
                            lp_alpha=(0.0, 0.4, 0.6),
                            suggested: List[Tuple[int, int, float]] = None,
                            **fixed_args):
        _, _, _, tune_args = inspect.getargvalues(inspect.currentframe())
        tune_args.update(fixed_args)
        ex_args = ['exp', 'self', 'fixed_args', 'batch_size', 'max_len']
        if trials == 0:
            ex_args += ['beam_size', 'ensemble', 'lp_alpha']
        for x in ex_args:
            del tune_args[x]  # exclude some args

        _, step = exp.get_last_saved_model()
        tune_dir = exp.work_dir / f'tune_step{step}'
        log.info(f"Tune dir = {tune_dir}")
        tune_dir.mkdir(parents=True, exist_ok=True)
        tune_src, tune_ref = Path(tune_src), Path(tune_ref)
        assert tune_src.exists()
        assert tune_ref.exists()
        tune_src, tune_ref = list(IO.get_lines(tune_src)), list(
            IO.get_lines(tune_ref))
        assert len(tune_src) == len(tune_ref)

        tune_log = tune_dir / 'scores.json'  # resume the tuning
        memory: Dict[Tuple, float] = {}
        if tune_log.exists():
            data = json.load(tune_log.open())
            # JSON keys cant be tuples, so they were stringified
            memory = {eval(k): v for k, v in data.items()}

        beam_sizes, ensembles, lp_alphas = [], [], []
        if suggested:
            if isinstance(suggested[0], str):
                suggested = [eval(x) for x in suggested]
            suggested = [(x[0], x[1], round(x[2], 2)) for x in suggested]
            suggested_new = [x for x in suggested if x not in memory]
            beam_sizes += [x[0] for x in suggested_new]
            ensembles += [x[1] for x in suggested_new]
            lp_alphas += [x[2] for x in suggested_new]

        new_trials = trials - len(memory)
        if new_trials > 0:
            beam_sizes += [random.choice(beam_size) for _ in range(new_trials)]
            ensembles += [random.choice(ensemble) for _ in range(new_trials)]
            lp_alphas += [
                round(random.choice(lp_alpha), 2) for _ in range(new_trials)
            ]

        # ensembling is somewhat costlier, so minimize model reloading by grouping trials that share the same ensemble size
        grouped_ens = defaultdict(list)
        for b, ens, l in zip(beam_sizes, ensembles, lp_alphas):
            grouped_ens[ens].append((b, l))
        try:
            for ens, args in grouped_ens.items():
                decoder = Decoder.new(exp, ensemble=ens)
                for b_s, lp_a in args:
                    eff_batch_size = batch_size // b_s  # effective batch size
                    name = f'tune_step{step}_beam{b_s}_ens{ens}_lp{lp_a:.2f}'
                    log.info(name)
                    out_file = tune_dir / f'{name}.out.tsv'
                    score = self.decode_eval_file(decoder,
                                                  tune_src,
                                                  out_file,
                                                  tune_ref,
                                                  batch_size=eff_batch_size,
                                                  beam_size=b_s,
                                                  lp_alpha=lp_a,
                                                  lowercase=lowercase,
                                                  **fixed_args)
                    memory[(b_s, ens, lp_a)] = score
            best_params = sorted(memory.items(),
                                 key=lambda x: x[1],
                                 reverse=True)[0][0]
            return dict(zip(['beam_size', 'ensemble', 'lp_alpha'],
                            best_params)), tune_args
        finally:
            # JSON keys cant be tuples, so we stringify them
            data = {str(k): v for k, v in memory.items()}
            IO.write_lines(tune_log, json.dumps(data))
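
The tuner caches scores keyed by (beam_size, ensemble, lp_alpha) tuples and persists them to scores.json; since JSON object keys must be strings, the tuples are stringified on write and rebuilt with eval on read, as in the code above. A small self-contained illustration of that round trip:

import json

# scores keyed by (beam_size, ensemble, lp_alpha), as in tune_decoder_params
memory = {(4, 1, 0.6): 31.2, (8, 5, 0.4): 32.0}

# write: JSON keys cannot be tuples, so stringify them
serialized = json.dumps({str(k): v for k, v in memory.items()})

# read: parse the stringified keys back into tuples
restored = {eval(k): v for k, v in json.loads(serialized).items()}
assert restored == memory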
Example #13
    def __init__(self,
                 exp: Experiment,
                 model: Optional[NMTModel] = None,
                 model_factory: Optional[Callable] = None,
                 optim: str = 'ADAM',
                 **optim_args):
        self.last_step = -1
        self.exp = exp
        optim_state = None
        if model:
            self.model = model
        else:
            args = exp.model_args
            assert args
            assert model_factory
            self.model, args = model_factory(exp=exp, **args)
            exp.model_args = args
            last_model, self.last_step = self.exp.get_last_saved_model()
            if last_model:
                log.info(
                    f"Resuming training from step:{self.last_step}, model={last_model}"
                )
                state = torch.load(last_model, map_location=device)
                model_state = state[
                    'model_state'] if 'model_state' in state else state

                if 'optim_state' in state:
                    optim_state = state['optim_state']
                self.model.load_state_dict(model_state)
                if 'amp_state' in state and dtorch.fp16:
                    log.info("Restoring  AMP state")
                    dtorch._scaler.load_state_dict(state['amp_state'])
            else:
                log.info(
                    "No earlier check point found. Looks like this is a fresh start"
                )

        # optimizer : default args for missing fields
        for k, v in self.default_optim_args.items():
            optim_args[k] = optim_args.get(k, v)

        self.n_gpus = torch.cuda.device_count()
        self.device_ids = list(range(self.n_gpus))

        inner_opt_args = {
            k: optim_args[k]
            for k in ['lr', 'betas', 'eps', 'weight_decay', 'amsgrad']
        }

        self.core_model = self.model.to(device)

        trainable_params = self.exp.config['optim'].get('trainable', {})
        if trainable_params:
            if dtorch.is_distributed:  # model is wrapped in DP or DistributedDP
                log.warning(
                    f">> Using more than 1 GPU with 'trainable' params is NOT tested"
                )
            trainable_params = self.core_model.get_trainable_params(
                include=trainable_params.get('include'),
                exclude=trainable_params.get('exclude'))
        else:
            trainable_params = self.model.parameters()

        inner_opt = Optims[optim].new(trainable_params, **inner_opt_args)
        self.model = dtorch.maybe_distributed(self.core_model)

        if optim_state:
            log.info("restoring optimizer state from checkpoint")
            try:
                inner_opt.load_state_dict(optim_state)
            except Exception:
                log.exception("Unable to restore optimizer, skipping it.")
        self.opt = NoamOpt(self.core_model.model_dim,
                           optim_args['constant'],
                           optim_args['warmup_steps'],
                           inner_opt,
                           step=self.start_step,
                           inv_sqrt=optim_args['inv_sqrt'])

        if self.exp.read_only:
            self.tbd = NoOpSummaryWriter()
        else:
            self.tbd = SummaryWriter(log_dir=str(exp.work_dir / 'tensorboard'))

        self.exp.optim_args = optim, optim_args
        if not self.exp.read_only:
            self.exp.persist_state()
        self.samples = None
        if exp.samples_file and exp.samples_file.exists():
            with IO.reader(exp.samples_file) as f:
                self.samples = [line.strip().split('\t') for line in f]
                log.info(f"Found {len(self.samples)} sample records")
                if self.start_step == 0:
                    for samp_num, sample in enumerate(self.samples):
                        self.tbd.add_text(f"sample/{samp_num}",
                                          " || ".join(sample), 0)

            from rtg.module.decoder import Decoder
            self.decoder = Decoder.new(self.exp, self.core_model)

        if self.start_step <= 1:
            self.maybe_init_model()

        self.criterion = self.create_criterion(optim_args['criterion'])