Example #1
    def _train_epoches(self,
                       train_set,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_set=None):

        log = self.logger

        las_print_loss_total = 0  # Reset every print_every
        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # ******************** [loop over epochs] ********************
        for epoch in range(start_epoch, n_epochs + 1):

            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # ----------construct batches-----------
            log.info('--- construct train set ---')
            train_set.construct_batches(is_train=True)
            if dev_set is not None:
                log.info('--- construct dev set ---')
                dev_set.construct_batches(is_train=True)

            # --------print info for each epoch----------
            steps_per_epoch = len(train_set.iter_loader)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.debug(" --------- Epoch: %d, Step: %d ---------" %
                      (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # ******************** [loop over batches] ********************
            model.train(True)
            trainiter = iter(train_set.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items = next(trainiter)

                # update macro count
                step += 1
                step_elapsed += 1

                # Get loss
                losses = self._train_batch(model, batch_items, train_set, step,
                                           total_steps)

                las_loss = losses['las_loss']
                las_print_loss_total += las_loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    las_print_loss_avg = las_print_loss_total / self.print_every
                    las_print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train las: %.4f'\
                     % (step / total_steps * 100, las_print_loss_avg)

                    log.info(log_msg)
                    self.writer.add_scalar('train_las_loss',
                                           las_print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        dev_accs, dev_losses = self._evaluate_batches(
                            model, dev_set)
                        las_loss = dev_losses['las_loss']
                        las_acc = dev_accs['las_acc']
                        log_msg = 'Progress: %d%%, Dev las loss: %.4f, accuracy: %.4f'\
                         % (step / total_steps * 100, las_loss, las_acc)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_las_loss',
                                               las_loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_las_acc',
                                               las_acc,
                                               global_step=step)

                        accuracy = las_acc
                        # save
                        if prev_acc < accuracy:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
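                            # 0.125 = 0.5 ** 3: stop once the lr falls below an eighth of its initial value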
                            if lr_curr < 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                        ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
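                # for-else: reached only when the batch loop above finished without a break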
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue

            # break nested for loop
            break
Example #2
    def _train_epoches(self,
                       train_sets,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_sets=None):

        # load datasets
        train_set_asr = train_sets['asr']
        dev_set_asr = dev_sets['asr']
        train_set_mt = train_sets['mt']
        dev_set_mt = dev_sets['mt']

        log = self.logger

        print_loss_ae_total = 0  # Reset every print_every
        print_loss_asr_total = 0
        print_loss_mt_total = 0
        print_loss_kl_total = 0
        print_loss_l2_total = 0

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        prev_bleu = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # loop over epochs
        for epoch in range(start_epoch, n_epochs + 1):

            # update lr
            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer,
                    step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)
            # print lr
            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # construct batches - allow re-shuffling of data
            log.info('--- construct train set ---')
            train_set_asr.construct_batches(is_train=True)
            train_set_mt.construct_batches(is_train=True)
            if dev_set_asr is not None:
                log.info('--- construct dev set ---')
                dev_set_asr.construct_batches(is_train=False)
                dev_set_mt.construct_batches(is_train=False)

            # print info
            steps_per_epoch_asr = len(train_set_asr.iter_loader)
            steps_per_epoch_mt = len(train_set_mt.iter_loader)
            steps_per_epoch = min(steps_per_epoch_asr, steps_per_epoch_mt)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.info(" ---------- Epoch: %d, Step: %d ----------" %
                     (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # loop over batches
            model.train(True)
            trainiter_asr = iter(train_set_asr.iter_loader)
            trainiter_mt = iter(train_set_mt.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items_asr = next(trainiter_asr)
                batch_items_mt = next(trainiter_mt)

                # update macro count
                step += 1
                step_elapsed += 1

                if self.lr_warmup_steps != 0:
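                    # refresh the lr every training step while warmup scheduling is enabled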
                    self.optimizer.optimizer = self.lr_scheduler(
                        self.optimizer.optimizer,
                        step,
                        init_lr=self.learning_rate_init,
                        peak_lr=self.learning_rate,
                        warmup_steps=self.lr_warmup_steps)

                # Get loss
                losses = self._train_batch(model, batch_items_asr,
                                           batch_items_mt, step, total_steps)
                loss_ae = losses['nll_loss_ae']
                loss_asr = losses['nll_loss_asr']
                loss_mt = losses['nll_loss_mt']
                loss_kl = losses['kl_loss']
                loss_l2 = losses['l2_loss']

                print_loss_ae_total += loss_ae
                print_loss_asr_total += loss_asr
                print_loss_mt_total += loss_mt
                print_loss_kl_total += loss_kl
                print_loss_l2_total += loss_l2

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_ae_avg = print_loss_ae_total / self.print_every
                    print_loss_ae_total = 0
                    print_loss_asr_avg = print_loss_asr_total / self.print_every
                    print_loss_asr_total = 0
                    print_loss_mt_avg = print_loss_mt_total / self.print_every
                    print_loss_mt_total = 0
                    print_loss_kl_avg = print_loss_kl_total / self.print_every
                    print_loss_kl_total = 0
                    print_loss_l2_avg = print_loss_l2_total / self.print_every
                    print_loss_l2_total = 0

                    log_msg = 'Progress: %d%%, Train nll_ae: %.4f, nll_asr: %.4f, ' % (
                        step / total_steps * 100, print_loss_ae_avg,
                        print_loss_asr_avg)
                    log_msg += 'Train nll_mt: %.4f, l2: %.4f, kl_en: %.4f' % (
                        print_loss_mt_avg, print_loss_l2_avg,
                        print_loss_kl_avg)
                    log.info(log_msg)

                    self.writer.add_scalar('train_loss_ae',
                                           print_loss_ae_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_asr',
                                           print_loss_asr_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_mt',
                                           print_loss_mt_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_kl',
                                           print_loss_kl_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_l2',
                                           print_loss_l2_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set_asr is not None:
                        losses, metrics = self._evaluate_batches(
                            model, dev_set_asr, dev_set_mt)

                        loss_kl = losses['kl_loss']
                        loss_l2 = losses['l2_loss']
                        loss_ae = losses['nll_loss_ae']
                        accuracy_ae = metrics['accuracy_ae']
                        bleu_ae = metrics['bleu_ae']
                        loss_asr = losses['nll_loss_asr']
                        accuracy_asr = metrics['accuracy_asr']
                        bleu_asr = metrics['bleu_asr']
                        loss_mt = losses['nll_loss_mt']
                        accuracy_mt = metrics['accuracy_mt']
                        bleu_mt = metrics['bleu_mt']

                        log_msg = 'Progress: %d%%, Dev AE loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_ae, accuracy_ae,
                            bleu_ae)
                        log.info(log_msg)
                        log_msg = 'Progress: %d%%, Dev ASR loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_asr, accuracy_asr,
                            bleu_asr)
                        log.info(log_msg)
                        log_msg = 'Progress: %d%%, Dev MT loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_mt, accuracy_mt,
                            bleu_mt)
                        log.info(log_msg)
                        log_msg = 'Progress: %d%%, Dev En KL loss: %.4f, L2 loss: %.4f' % (
                            step / total_steps * 100, loss_kl, loss_l2)
                        log.info(log_msg)

                        self.writer.add_scalar('dev_loss_l2',
                                               loss_l2,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_kl',
                                               loss_kl,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_ae',
                                               loss_ae,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_ae',
                                               accuracy_ae,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_ae',
                                               bleu_ae,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_asr',
                                               loss_asr,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_asr',
                                               accuracy_asr,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_asr',
                                               bleu_asr,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_mt',
                                               loss_mt,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_mt',
                                               accuracy_mt,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_mt',
                                               bleu_mt,
                                               global_step=step)

                        # save criterion: weighted mix of ASR (down-weighted by 4) and MT dev scores
                        accuracy_ave = (accuracy_asr / 4.0 + accuracy_mt) / 2.0
                        bleu_ave = (bleu_asr / 4.0 + bleu_mt) / 2.0
                        if ((prev_acc < accuracy_ave) and
                            (bleu_ave < 0.1)) or prev_bleu < bleu_ave:

                            # save best model - using bleu as metric
                            ckpt = Checkpoint(
                                model=model,
                                optimizer=self.optimizer,
                                epoch=epoch,
                                step=step,
                                input_vocab=train_set_asr.vocab_src,
                                output_vocab=train_set_asr.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy_ave
                            prev_bleu = bleu_ave
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # rollback disabled: stop after max_count_no_improve evaluations without improvement
                            if self.max_count_num_rollback == 0:
                                break
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr <= 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is not None:
                            ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                if dev_set_asr is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set_asr.vocab_src,
                                      output_vocab=train_set_asr.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue

            # break nested for loop
            break
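
The lr_scheduler callable invoked in the warmup branches above (and again in Example #3) is not shown in these listings. The helper below is only a minimal sketch of what it might look like, inferred from the call site: it receives the wrapped optimizer, the current step, init_lr, peak_lr and warmup_steps, and its return value is assigned back to self.optimizer.optimizer. The linear-ramp schedule itself is an assumption, not the project's actual implementation.

def linear_warmup_lr_scheduler(optimizer, step, init_lr, peak_lr, warmup_steps):
    # Ramp the lr linearly from init_lr to peak_lr over warmup_steps,
    # then hold it at peak_lr; the caller re-assigns the returned optimizer.
    if step < warmup_steps:
        lr = init_lr + (peak_lr - init_lr) * step / warmup_steps
    else:
        lr = peak_lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return optimizer
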
Example #3
    def _train_epochs(self,
                      train_set,
                      model,
                      n_epochs,
                      start_epoch,
                      start_step,
                      dev_set=None):

        log = self.logger

        print_loss_total = 0  # Reset every print_every
        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        prev_bleu = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # loop over epochs
        for epoch in range(start_epoch, n_epochs + 1):

            # update lr
            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer,
                    step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)

            # print lr
            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # construct batches - allow re-shuffling of data
            log.info('--- construct train set ---')
            train_set.construct_batches(is_train=True)
            if dev_set is not None:
                log.info('--- construct dev set ---')
                dev_set.construct_batches(is_train=False)

            # print info
            steps_per_epoch = len(train_set.iter_loader)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.info(" ---------- Epoch: %d, Step: %d ----------" %
                     (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # loop over batches
            model.train(True)
            trainiter = iter(train_set.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items = next(trainiter)

                # update macro count
                step += 1
                step_elapsed += 1

                if self.lr_warmup_steps != 0:
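                    # per-step lr update while the warmup schedule is active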
                    self.optimizer.optimizer = self.lr_scheduler(
                        self.optimizer.optimizer,
                        step,
                        init_lr=self.learning_rate_init,
                        peak_lr=self.learning_rate,
                        warmup_steps=self.lr_warmup_steps)

                # Get loss
                loss = self._train_batch(model, batch_items, train_set, step,
                                         total_steps)
                print_loss_total += loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train nll: %.4f' % (
                        step / total_steps * 100, print_loss_avg)

                    log.info(log_msg)
                    self.writer.add_scalar('train_loss',
                                           print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        losses, metrics = self._evaluate_batches(
                            model, dev_set)

                        loss = losses['nll_loss']
                        accuracy = metrics['accuracy']
                        bleu = metrics['bleu']
                        log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss, accuracy, bleu)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_loss',
                                               loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc',
                                               accuracy,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu',
                                               bleu,
                                               global_step=step)

                        # save condition
                        cond_acc = (prev_acc <= accuracy)
                        cond_bleu = (((prev_acc <= accuracy) and (bleu < 0.1))
                                     or prev_bleu <= bleu)

                        # save
                        if self.eval_metric == 'tokacc':
                            save_cond = cond_acc
                        elif self.eval_metric == 'bleu':
                            save_cond = cond_bleu
                        if save_cond:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            prev_bleu = bleu
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # no rollback configured: stop after max_count_no_improve evaluations without improvement
                            if self.max_count_num_rollback == 0:
                                break
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr <= 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                        ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue

            # break nested for loop
            break
Example #4
    def _train_epoches(self,
                       train_set,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_set=None):

        log = self.logger

        print_loss_total = 0  # Reset every print_every
        epoch_loss_total = 0  # Reset every epoch
        att_print_loss_total = 0  # Reset every print_every
        att_epoch_loss_total = 0  # Reset every epoch
        attcls_print_loss_total = 0  # Reset every print_every
        attcls_epoch_loss_total = 0  # Reset every epoch

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # ******************** [loop over epochs] ********************
        for epoch in range(start_epoch, n_epochs + 1):

            for param_group in self.optimizer.optimizer.param_groups:
                print('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # ----------construct batches-----------
            # allow re-shuffling of data
            if train_set.attkey_path is None:
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    dev_batches, vocab_size = dev_set.construct_batches(
                        is_train=False)
            else:
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches_with_ddfd_prob(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    assert dev_set.attkey_path is not None, \
                        'Dev set missing ddfd probabilities'
                    dev_batches, vocab_size = dev_set.construct_batches_with_ddfd_prob(
                        is_train=False)

            # --------print info for each epoch----------
            steps_per_epoch = len(train_batches)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.debug(
                " ----------------- Epoch: %d, Step: %d -----------------" %
                (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            print('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # ******************** [loop over batches] ********************
            model.train(True)
            for batch in train_batches:

                # update macro count
                step += 1
                step_elapsed += 1

                # load data
                src_ids = batch['src_word_ids']
                src_lengths = batch['src_sentence_lengths']
                tgt_ids = batch['tgt_word_ids']
                tgt_lengths = batch['tgt_sentence_lengths']

                src_probs = None
                src_labs = None
                if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                    src_probs = batch['src_ddfd_probs']
                    src_probs = _convert_to_tensor(src_probs,
                                                   self.use_gpu).unsqueeze(2)
                if 'src_ddfd_labs' in batch:
                    src_labs = batch['src_ddfd_labs']
                    src_labs = _convert_to_tensor(src_labs,
                                                  self.use_gpu).unsqueeze(2)

                # sanity check src-tgt pair
                if step == 1:
                    print('--- Check src tgt pair ---')
                    log_msgs = check_srctgt(src_ids, tgt_ids,
                                            train_set.src_id2word,
                                            train_set.tgt_id2word)
                    for log_msg in log_msgs:
                        sys.stdout.buffer.write(log_msg)

                # convert variable to tensor
                src_ids = _convert_to_tensor(src_ids, self.use_gpu)
                tgt_ids = _convert_to_tensor(tgt_ids, self.use_gpu)

                # Get loss
                loss, att_loss, attcls_loss = self._train_batch(
                    src_ids,
                    tgt_ids,
                    model,
                    step,
                    total_steps,
                    src_probs=src_probs,
                    src_labs=src_labs)

                print_loss_total += loss
                epoch_loss_total += loss
                att_print_loss_total += att_loss
                att_epoch_loss_total += att_loss
                attcls_print_loss_total += attcls_loss
                attcls_epoch_loss_total += attcls_loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    att_print_loss_avg = att_print_loss_total / self.print_every
                    attcls_print_loss_avg = attcls_print_loss_total / self.print_every
                    print_loss_total = 0
                    att_print_loss_total = 0
                    attcls_print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train nll: %.4f, att: %.4f, attcls: %.4f' % (
                        step / total_steps * 100, print_loss_avg,
                        att_print_loss_avg, attcls_print_loss_avg)
                    # print(log_msg)
                    log.info(log_msg)
                    self.writer.add_scalar('train_loss',
                                           print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('att_train_loss',
                                           att_print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('attcls_train_loss',
                                           attcls_print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        dev_loss, accuracy, dev_attlosses = \
                         self._evaluate_batches(model, dev_batches, dev_set)
                        dev_attloss = dev_attlosses['att_loss']
                        dev_attclsloss = dev_attlosses['attcls_loss']
                        log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, att: %.4f, attcls: %.4f' % (
                            step / total_steps * 100, dev_loss, accuracy,
                            dev_attloss, dev_attclsloss)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_loss',
                                               dev_loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc',
                                               accuracy,
                                               global_step=step)
                        self.writer.add_scalar('att_dev_loss',
                                               dev_attloss,
                                               global_step=step)
                        self.writer.add_scalar('attcls_dev_loss',
                                               dev_attclsloss,
                                               global_step=step)

                        # save
                        if prev_acc < accuracy:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            print('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > MAX_COUNT_NO_IMPROVE:
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim\
                                 .__class__(model.parameters(), **defaults)
                                # start_epoch = resume_checkpoint.epoch
                                # step = resume_checkpoint.step

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > MAX_COUNT_NUM_ROLLBACK:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so it tracks the restored model's parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim\
                                 .__class__(model.parameters(), **defaults)
                                start_epoch = resume_checkpoint.epoch
                                step = resume_checkpoint.step

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                print('reducing lr ...')
                                print('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr < 0.000125:
                                print('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                            saved_path = ckpt.save(self.expt_dir)
                        ckpt.rm_old(self.expt_dir, keep_num=KEEP_NUM)
                        print('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))
                    sys.stdout.flush()

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    # saved_path = ckpt.save(self.expt_dir)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    print('saving at {} ... '.format(saved_path))
                    continue

                else:
                    continue
            # break nested for loop
            break
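            # NOTE: the epoch-summary code below is unreachable - the for-else above
            # always continues, and a break from the batch loop exits the epoch loop here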

            if step_elapsed == 0: continue
            epoch_loss_avg = epoch_loss_total / min(steps_per_epoch,
                                                    step - start_step)
            epoch_loss_total = 0
            log_msg = "Finished epoch %d: Train %s: %.4f" % (
                epoch, self.loss.name, epoch_loss_avg)

            log.info('\n')
            log.info(log_msg)
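
All four examples rely on the same Checkpoint helper. The stub below lists the interface they assume, reconstructed purely from the calls made above; the method bodies are placeholders, not the project's implementation.

class Checkpoint:
    # Sketch of the interface used by the training loops above.
    def __init__(self, model, optimizer, epoch, step, input_vocab, output_vocab):
        self.model = model
        self.optimizer = optimizer
        self.epoch = epoch
        self.step = step
        self.input_vocab = input_vocab
        self.output_vocab = output_vocab

    def save(self, expt_dir):
        # persist model/optimizer state under expt_dir and return the saved path
        ...

    def save_epoch(self, expt_dir, epoch):
        # per-epoch save, used when no dev set is provided
        ...

    def rm_old(self, expt_dir, keep_num):
        # delete all but the newest keep_num checkpoints
        ...

    @classmethod
    def get_latest_checkpoint(cls, expt_dir):
        # return the path of the most recent checkpoint in expt_dir, or None
        ...

    @classmethod
    def load(cls, path):
        # rebuild a Checkpoint (exposing .model and .optimizer) from disk
        ...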