Example #1
File: rnnmt.py  Project: isi-nlp/rtg
 def run_valid_epoch(self, data_iter: BatchIterable) -> float:
     state = TrainerState(self.model, -1)
     with tqdm(data_iter, total=data_iter.num_batches, unit='batch',
               dynamic_ncols=True) as data_bar:
         for i, batch in enumerate(data_bar):
             batch = batch.to(device)
             # Step: clear the gradients
             self.model.zero_grad()
             # Step: run the forward pass
             outp_log_probs = self.model(batch)
             loss = self.loss_func(outp_log_probs, batch, train_mode=False)
             bar_msg, _ = state.step(batch.y_toks, loss)
             data_bar.set_postfix_str(bar_msg, refresh=False)
             del batch
     return state.running_loss()
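The validation pass above only zeroes gradients; it does not disable gradient tracking (Example #6 below does wrap the call in torch.no_grad() at the call site). A minimal sketch of calling it under torch.no_grad() and eval mode, assuming the model follows standard nn.Module semantics; the helper name validate_no_grad is hypothetical:

import torch

def validate_no_grad(trainer, data_iter):
    # switch to eval mode and disable gradient tracking for the whole pass
    trainer.model.eval()
    try:
        with torch.no_grad():
            return trainer.run_valid_epoch(data_iter)
    finally:
        # restore training mode afterwards
        trainer.model.train()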
Example #2
File: rnnmt.py  Project: isi-nlp/rtg
    def train(self, steps: int, check_point: int, batch_size: int, fine_tune=False,
              check_pt_callback: Optional[Callable] = None, **args):
        log.info(f'Going to train for {steps} steps; batch_size={batch_size}; '
                 f'check point size:{check_point}; fine tune={fine_tune}')
        keep_models = args.get('keep_models', 4)  # keep the last `keep_models` models and delete older ones
        sort_by = args.get('sort_by', 'random')
        if steps <= self.start_step:
            raise Exception(f'The model was already trained to {self.start_step} steps. '
                            f'Please increase the steps or clear the existing models')
        train_data = self.exp.get_train_data(batch_size=batch_size, steps=steps - self.start_step,
                                             sort_by=sort_by, shuffle=True, batch_first=True,
                                             fine_tune=fine_tune)
        val_data = self.exp.get_val_data(batch_size, shuffle=False, batch_first=True,
                                         sort_desc=True)

        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)
        unsaved_state = False
        with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
                  dynamic_ncols=True) as data_bar:
            for batch in data_bar:
                batch = batch.to(device)
                # Step: clear the gradients
                self.model.zero_grad()
                # Step: run the forward pass
                outp_log_probs = self.model(batch)

                loss = self.loss_func(outp_log_probs, batch, True)
                unsaved_state = True
                self.tbd.add_scalars('training', {'step_loss': loss,
                                                  'learn_rate': self.opt.curr_lr},
                                     self.opt.curr_step)
                bar_msg, is_check_pt = train_state.step(batch.y_toks, loss)
                bar_msg += f', LR={self.opt.curr_lr:g}'
                data_bar.set_postfix_str(bar_msg, refresh=False)

                del batch  # TODO: force free memory
                if is_check_pt:
                    train_loss = train_state.reset()
                    train_state.train_mode(False)
                    val_loss = self.run_valid_epoch(val_data)
                    self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
                    if check_pt_callback:
                        check_pt_callback(model=self.model,
                                          step=self.opt.curr_step,
                                          train_loss=train_loss)
                    train_state.train_mode(True)
                    unsaved_state = False

        if unsaved_state:
            # End of training
            train_loss = train_state.reset()
            train_state.train_mode(False)
            val_loss = self.run_valid_epoch(val_data)
            self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
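TrainerState's internals are not shown in these examples; the loops only rely on a step/reset/running_loss protocol that tracks a running loss and signals when a checkpoint is due. A minimal sketch of that bookkeeping, assuming a fixed checkpoint interval and a simple average over steps (illustrative only, not rtg's TrainerState):

class RunningLossState:
    def __init__(self, check_point: int):
        self.check_point = check_point
        self.steps, self.total_loss, self.total_toks = 0, 0.0, 0

    def step(self, num_toks: int, loss: float):
        # accumulate one batch; report progress and whether a checkpoint is due
        self.steps += 1
        self.total_loss += loss
        self.total_toks += num_toks
        msg = f'Loss:{self.running_loss():.4f}, {self.total_toks} toks'
        return msg, self.steps % self.check_point == 0

    def running_loss(self) -> float:
        return self.total_loss / max(self.steps, 1)

    def reset(self) -> float:
        # return the average loss since the last reset and clear the counters
        avg = self.running_loss()
        self.steps, self.total_loss, self.total_toks = 0, 0.0, 0
        return avg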
Example #3
File: tfmnmt.py  Project: MGheini/rtg
    def train(self,
              steps: int,
              check_point: int,
              batch_size: int,
              check_pt_callback: Optional[Callable] = None,
              fine_tune=False,
              dec_bos_cut=False,
              keep_models=10,
              sort_by='eq_len_rand_batch',
              log_interval: int = 10,
              **args):
        """

        :param steps: how many optimizer steps to train (also, means how many batches)
        :param check_point: after how many steps to save a checkpoint
        :param batch_size: how many target tokens in batch max ( = max_len * num_sentences)
        :param check_pt_callback: function to call back after checkpt
        :param fine_tune: should the fine tune corpus be used instead of training corpus
        :param dec_bos_cut: copy the first time step of input as decoder's BOS
        :param keep_models: how many checkpoints to keep
        :param args: any extra args
        :return:
        """
        log_resources = args.pop('log_resources', False)
        assert log_interval > 0
        if args:
            # no extra args are expected; let the user know if one was passed
            raise Exception(f" Found extra args: {args}")
        log.info(
            f'Going to train for {steps} steps (from {self.start_step} steps);'
            f' batch_size={batch_size} toks; sort_by={sort_by};'
            f' check point size:{check_point}; fine_tune={fine_tune};'
            f' dec_bos_cut={dec_bos_cut}')
        if self.n_gpus > 1:
            batch_size *= self.n_gpus
            log.info(
                f"# GPUs = {self.n_gpus}, batch_size is set to {batch_size}")

        if steps <= self.start_step:
            raise Exception(
                f'The model was already trained to {self.start_step} steps. '
                f'Please increase the steps or clear the existing models')
        train_data = self.exp.get_train_data(batch_size=batch_size,
                                             steps=steps - self.start_step,
                                             sort_by=sort_by,
                                             batch_first=True,
                                             fine_tune=fine_tune)
        val_data = self.exp.get_val_data(batch_size,
                                         shuffle=False,
                                         batch_first=True,
                                         sort_desc=False)

        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)
        unsaved_state = False
        cuda_available = torch.cuda.is_available()
        with tqdm(train_data,
                  initial=self.start_step,
                  total=steps,
                  unit='batch',
                  dynamic_ncols=True) as data_bar:
            for batch in data_bar:
                self.model.zero_grad()
                batch = batch.to(device)
                num_toks = batch.y_toks
                x_seqs = batch.x_seqs
                if dec_bos_cut:
                    bos_step = x_seqs[:, :1]
                    x_seqs = x_seqs[:, 1:]
                else:
                    bos_step = torch.full((len(batch), 1),
                                          fill_value=Batch.bos_val,
                                          dtype=torch.long,
                                          device=device)
                x_mask = (x_seqs != batch.pad_value).unsqueeze(1)
                y_seqs_with_bos = torch.cat([bos_step, batch.y_seqs], dim=1)
                y_mask = Batch.make_target_mask(y_seqs_with_bos)
                out = self.model(x_seqs, y_seqs_with_bos, x_mask, y_mask)
                # [Batch x Time x D]
                # skip the last time step (the one with EOS as input)
                out = out[:, :-1, :]
                # assumption:  y_seqs has EOS, and not BOS
                loss = self.loss_func(out, batch.y_seqs, num_toks, True)
                unsaved_state = True
                if self.opt.curr_step % log_interval == 0:
                    self.tbd.add_scalars('training', {
                        'step_loss': loss,
                        'learn_rate': self.opt.curr_lr
                    }, self.opt.curr_step)
                    if log_resources and cuda_available:
                        self._log_resources(batch)

                progress_msg, is_check_pt = train_state.step(num_toks, loss)
                progress_msg += f', LR={self.opt.curr_lr:g}'

                data_bar.set_postfix_str(progress_msg, refresh=False)
                del batch

                if is_check_pt:
                    train_loss = train_state.reset()
                    train_state.train_mode(False)
                    val_loss = self.run_valid_epoch(val_data,
                                                    dec_bos_cut=dec_bos_cut)
                    self.make_check_point(train_loss,
                                          val_loss=val_loss,
                                          keep_models=keep_models)
                    if check_pt_callback:
                        check_pt_callback(model=self.model,
                                          step=self.opt.curr_step,
                                          train_loss=train_loss)
                    train_state.train_mode(True)
                    unsaved_state = False
                    gc.collect()

        # End of training
        if unsaved_state:
            train_loss = train_state.reset()
            train_state.train_mode(False)
            val_loss = self.run_valid_epoch(val_data, dec_bos_cut=dec_bos_cut)
            self.make_check_point(train_loss,
                                  val_loss=val_loss,
                                  keep_models=keep_models)
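Batch.make_target_mask is used above but not shown. A generic sketch of the usual autoregressive decoder mask (padding mask combined with a lower-triangular causal mask), under the assumption that rtg follows this common Transformer convention:

import torch

def autoreg_mask(seqs: torch.Tensor, pad_val: int = 0) -> torch.Tensor:
    # seqs: [Batch x Time] token ids; returns a [Batch x Time x Time] bool mask
    pad_mask = (seqs != pad_val).unsqueeze(1)                    # [B, 1, T]
    time_steps = seqs.size(1)
    causal = torch.tril(torch.ones(time_steps, time_steps, device=seqs.device)).bool()
    return pad_mask & causal                                     # broadcasts to [B, T, T]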
Example #4
File: tfmlm.py  Project: MGheini/rtg
    def train(self,
              steps: int,
              check_point: int,
              batch_size: int,
              check_pt_callback: Optional[Callable] = None,
              keep_models=4,
              **args):
        log.info(
            f'Going to train for {steps} steps; batch_size={batch_size}; '
            f'check point size:{check_point}')

        rem_steps = steps - self.start_step
        if rem_steps <= 0:
            raise Exception(
                f'The model was already trained to {self.start_step} steps. '
                f'Please increase the steps or clear the existing models')
        side = 'tgt'  # TODO: this should be inferrable or configurable instead of hardcoded

        train_data = self.exp.get_mono_data('train',
                                            side,
                                            batch_size=batch_size,
                                            batch_first=True,
                                            sort_dec=False,
                                            num_batches=rem_steps,
                                            shuffle=True)
        val_data = self.exp.get_mono_data('valid',
                                          side,
                                          batch_size=batch_size,
                                          batch_first=True,
                                          sort_dec=False)

        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)
        unsaved_state = False
        with tqdm(train_data,
                  initial=self.start_step,
                  total=steps,
                  unit='batch',
                  dynamic_ncols=True) as data_bar:
            for batch in data_bar:
                self.model.zero_grad()
                assert batch.eos_x  # must have EOS
                batch = batch.to(device)
                num_toks = batch.x_toks
                seqs = batch.x_seqs
                bos_step = torch.full((len(batch), 1),
                                      fill_value=Batch.bos_val,
                                      dtype=torch.long,
                                      device=device)
                seqs_with_bos = torch.cat([bos_step, batch.x_seqs], dim=1)
                seq_mask = Batch.make_target_mask(seqs_with_bos)
                out = self.model(seqs_with_bos, seq_mask, gen_probs=False)
                # [Batch x Time x D]
                # skip the last time step (the one with EOS as input)
                out = out[:, :-1, :]
                # assumption: seqs has EOS, and not BOS
                loss = self.loss_func(out, seqs, num_toks, True)
                unsaved_state = True
                self.tbd.add_scalars('training', {
                    'step_loss': loss,
                    'learn_rate': self.opt.curr_lr
                }, self.opt.curr_step)

                progress_msg, is_check_pt = train_state.step(num_toks, loss)
                progress_msg += f', LR={self.opt.curr_lr:g}'

                data_bar.set_postfix_str(progress_msg, refresh=False)
                del batch

                if is_check_pt:
                    train_loss = train_state.reset()
                    train_state.train_mode(False)
                    val_loss = self.run_valid_epoch(val_data)
                    self.make_check_point(train_loss,
                                          val_loss,
                                          keep_models=keep_models)
                    if check_pt_callback:
                        check_pt_callback(model=self.model,
                                          step=self.opt.curr_step,
                                          train_loss=train_loss)
                    train_state.train_mode(True)
                    unsaved_state = False

        if unsaved_state:
            # End of training
            train_loss = train_state.reset()
            train_state.train_mode(False)
            val_loss = self.run_valid_epoch(val_data)
            self.make_check_point(train_loss,
                                  val_loss,
                                  keep_models=keep_models)
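The "skip the last time step" and "has EOS, and not BOS" comments above describe the standard teacher-forcing shift. A small worked example of that alignment (token ids are arbitrary, chosen only for illustration):

import torch

BOS, EOS = 1, 2
y = torch.tensor([[5, 7, 9, EOS]])                 # reference: has EOS, no BOS
bos = torch.full((1, 1), BOS, dtype=torch.long)
y_in = torch.cat([bos, y], dim=1)                  # model input: [BOS, 5, 7, 9, EOS]
# the model emits one prediction per input position (5 positions here);
# dropping the last one leaves 4 predictions, one for each token of y
assert y_in[:, 1:].equal(y)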
Example #5
    def train(self, steps: int, check_point: int, batch_size: int,
              check_pt_callback: Optional[Callable] = None, **args):
        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)
        if self.start_step >= steps:
            log.warning(f"Already trained to  {self.start_step}. Considering it as done.")
            return
        rem_steps = steps - self.start_step
        side = 'tgt'     # TODO: this should be inferrable or configurable instead of hardcoded

        train_data = self.exp.get_mono_data('train', side, batch_size=batch_size,
                                            batch_first=True, sort_dec=True,
                                            num_batches=rem_steps, shuffle=True)
        val_data = self.exp.get_mono_data('valid', side, batch_size=batch_size,
                                          batch_first=True, sort_dec=True)

        keep_models = 8
        unsaved_state = False
        with tqdm(train_data, initial=self.start_step, total=steps, unit='batch',
                  dynamic_ncols=True) as data_bar:
            for batch in data_bar:
                batch.to(device)
                outp_log_probs = self.model.batch_forward(batch)
                loss = self.simple_loss_func(outp_log_probs, seq_lens=batch.x_len,
                                             tot_toks=batch.x_toks, max_seq_len=batch.max_x_len,
                                             train_mode=True)
                unsaved_state = True
                bar_msg, is_check_pt = train_state.step(batch.x_toks, loss)
                data_bar.set_postfix_str(bar_msg, refresh=True)
                del batch       # TODO: force free memory
                if is_check_pt:
                    train_loss = train_state.reset()
                    train_state.train_mode(False)
                    val_loss = self.run_valid_epoch(val_data)
                    self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
                    if check_pt_callback:
                        check_pt_callback(model=self.model,
                                          step=self.opt.curr_step,
                                          train_loss=train_loss)
                    train_state.train_mode(True)
                    unsaved_state = False

        log.info("End of training session")
        if unsaved_state:
            # End of training
            train_loss = train_state.reset()
            train_state.train_mode(False)
            val_loss = self.run_valid_epoch(val_data)
            self.make_check_point(train_loss, val_loss=val_loss, keep_models=keep_models)
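simple_loss_func is not shown in this example. Assuming outp_log_probs holds one log-probability per reference token ([Batch x Time]), a plausible sketch of a length-masked negative log-likelihood, offered only to illustrate the masking idea (the real function also takes tot_toks, max_seq_len and train_mode, omitted here):

import torch

def length_masked_nll(tok_log_probs: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # tok_log_probs: [Batch x Time]; seq_lens: [Batch] true sequence lengths
    time_steps = tok_log_probs.size(1)
    positions = torch.arange(time_steps, device=tok_log_probs.device)
    mask = positions[None, :] < seq_lens[:, None]          # valid (non-padding) positions
    return -(tok_log_probs * mask).sum() / mask.sum()      # mean NLL over real tokens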
Example #6
    def train(self,
              steps: int,
              check_point: int,
              batch_size: int,
              check_pt_callback: Optional[Callable] = None,
              fine_tune=False,
              dec_bos_cut=False,
              keep_models=10,
              sort_by='eq_len_rand_batch',
              log_interval: int = 10,
              keep_in_mem=False,
              early_stop=None,
              **args):
        """
        :param steps: how many optimizer steps to train (also, means how many batches)
        :param check_point: after how many steps to save a checkpoint
        :param batch_size: how many target tokens in batch max ( = max_len * num_sentences)
        :param check_pt_callback: function to call back after checkpt
        :param fine_tune: should the fine tune corpus be used instead of training corpus
        :param dec_bos_cut: copy the first time step of input as decoder's BOS
        :param keep_models: how many checkpoints to keep
        :param keep_in_mem: keep training data in memory
        :param early_stop: {patience: N validations, by: loss, enabled: True}
        :param args: any extra args
        :return:
        """
        log_resources = args.pop('log_resources', False)
        log_embedding = args.pop('log_embedding', False)
        split_ratio = args.pop('split_ratio', 0.)
        dynamic_epoch = args.pop('dynamic_epoch', False)
        assert log_interval > 0

        # Gradient accumulation
        opt_steps = steps
        batches = steps * self.grad_accum_interval
        start_batch = self.start_step * self.grad_accum_interval
        check_point = check_point * self.grad_accum_interval
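        # e.g. with grad_accum_interval=4 and steps=1000 (illustrative values),
        # the loop below reads 4000 mini-batches and takes an optimizer step on
        # every 4th one; the checkpoint interval is scaled the same way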
        if isinstance(batch_size, int):
            max_toks, max_sents = batch_size, float('inf')
        else:
            max_toks, max_sents = batch_size
        if args:
            # no extra args are expected; let the user know if one was passed
            raise Exception(f" Found extra args: {args}")
        log.info(
            f'Going to train for {opt_steps} optimizer steps over {batches} batches'
            f' (from {self.start_step} steps);'
            f' batch_size={batch_size} toks; sort_by={sort_by};'
            f' check point size:{check_point}; fine_tune={fine_tune};'
            f' dec_bos_cut={dec_bos_cut}')

        distr = DistribTorch.instance()
        if batches <= start_batch:
            raise Exception(
                f'The model was already trained to {self.start_step} steps. '
                f'Please increase the steps or clear the existing models')

        train_data = self.exp.get_train_data(batch_size=batch_size,
                                             steps=batches - start_batch,
                                             sort_by=sort_by,
                                             batch_first=True,
                                             fine_tune=fine_tune,
                                             keep_in_mem=keep_in_mem,
                                             split_ratio=split_ratio,
                                             dynamic_epoch=dynamic_epoch)
        val_data = None
        if distr.is_global_main:
            val_data = self.exp.get_val_data(batch_size=max_toks,
                                             shuffle=False,
                                             batch_first=True,
                                             sort_desc=False)

        train_state = TrainerState(self.model, check_point=check_point)
        train_state.train_mode(True)
        unsaved_state = False
        cuda_available = torch.cuda.is_available()

        batch_count = -1
        stopper = None
        early_stopped = False  # or converged
        if early_stop:
            stopper = EarlyStopper(cur_step=self.start_step, **early_stop)

        with tqdm(train_data,
                  initial=start_batch,
                  total=batches,
                  unit='batch',
                  dynamic_ncols=True,
                  disable=not distr.is_global_main) as data_bar:
            for batch in data_bar:
                batch_count += 1
                take_step = (batch_count % self.grad_accum_interval) == 0

                # if update_interval == 0:
                #     self.model.zero_grad()

                # if not using DataParallel, move the batch to the device here
                if self.n_gpus <= 1:
                    batch = batch.to(device)
                num_toks = batch.y_toks
                x_seqs = batch.x_seqs
                if dec_bos_cut:
                    bos_step = x_seqs[:, :1]
                    x_seqs = x_seqs[:, 1:]
                else:
                    bos_step = torch.full((len(batch), 1),
                                          fill_value=batch.bos_val,
                                          dtype=torch.long,
                                          device=batch.y_seqs.device)

                # Prep masks
                x_mask = (x_seqs != batch.pad_val).unsqueeze(1)
                y_seqs_with_bos = torch.cat([bos_step, batch.y_seqs], dim=1)
                y_mask = batch.make_autoreg_mask(y_seqs_with_bos)

                with autocast(enabled=dtorch.fp16):
                    # [Batch x Time x D]
                    out = self.model(x_seqs, y_seqs_with_bos, x_mask, y_mask)

                    # skip the last time step (the one with EOS as input)
                    out = out[:, :-1, :]

                    # assumption:  y_seqs has EOS, and not BOS
                    loss = self.loss_func(out,
                                          batch.y_seqs,
                                          num_toks,
                                          train_mode=True,
                                          take_step=take_step)

                if stopper and take_step:
                    stopper.step()
                # Log
                unsaved_state = True
                if self.opt.curr_step % log_interval == 0:
                    self.tbd.add_scalars('training', {
                        'step_loss': loss,
                        'learn_rate': self.opt.curr_lr
                    }, self.opt.curr_step)
                    if log_resources and cuda_available:
                        self._log_resources(batch)

                progress_msg, is_check_pt = train_state.step(num_toks, loss)
                progress_msg += f', LR={self.opt.curr_lr:0.8f}'
                data_bar.set_postfix_str(progress_msg, refresh=False)
                del batch

                # Save checkpoint
                if is_check_pt:
                    train_loss = train_state.reset()
                    log.info(
                        f"Chkpt Train loss={train_loss}; Runs validation? {distr.is_global_main}"
                    )
                    if distr.is_global_main:
                        train_state.train_mode(False)
                        with torch.no_grad():
                            val_loss = self.run_valid_epoch(
                                val_data, dec_bos_cut=dec_bos_cut)
                            self.make_check_point(train_loss,
                                                  val_loss=val_loss,
                                                  keep_models=keep_models,
                                                  log_embedding=log_embedding)
                            if check_pt_callback:
                                check_pt_callback(model=self.model,
                                                  step=self.opt.curr_step,
                                                  train_loss=train_loss)
                        train_state.train_mode(True)

                        if stopper:
                            stopper.validation(val_loss)
                            if stopper.is_stop():
                                log.info(
                                    f"Stopping at {stopper.cur_step} because {stopper.by}"
                                    f" didnt improve over {stopper.patience} checkpoints"
                                )
                                early_stopped = True
                                break
                    unsaved_state = False
                    gc.collect()
                    distr.barrier()

        # End of training
        if unsaved_state and distr.is_global_main:
            train_loss = train_state.reset()
            train_state.train_mode(False)
            val_loss = self.run_valid_epoch(val_data, dec_bos_cut=dec_bos_cut)
            self.make_check_point(train_loss,
                                  val_loss=val_loss,
                                  keep_models=keep_models)

        distr.barrier()
        return early_stopped
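EarlyStopper's implementation is not included above; the loop only calls step(), validation(val_loss), and is_stop(), and reads its by, patience, and cur_step attributes. A minimal patience-based sketch of that protocol (illustrative only, not rtg's EarlyStopper):

class SimpleStopper:
    def __init__(self, patience: int = 5, by: str = 'loss', cur_step: int = 0, enabled: bool = True):
        self.patience, self.by, self.cur_step, self.enabled = patience, by, cur_step, enabled
        self.best = float('inf')
        self.bad_checks = 0          # validations since the last improvement

    def step(self):
        # called once per optimizer step
        self.cur_step += 1

    def validation(self, val_loss: float):
        # called once per checkpoint with the validation loss
        if val_loss < self.best:
            self.best, self.bad_checks = val_loss, 0
        else:
            self.bad_checks += 1

    def is_stop(self) -> bool:
        return self.enabled and self.bad_checks >= self.patience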