Example #1
    def _train_epoches(self,
                       train_sets,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_sets=None):

        # load relevant dataset
        train_set = train_sets['asr']
        dev_set = dev_sets['asr'] if dev_sets is not None else None

        log = self.logger

        print_loss_de_total = 0  # Reset every print_every
        print_loss_en_total = 0

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        prev_bleu = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # loop over epochs
        for epoch in range(start_epoch, n_epochs + 1):

            # update lr
            if self.lr_warmup_steps != 0:
                self.optimizer.optimizer = self.lr_scheduler(
                    self.optimizer.optimizer,
                    step,
                    init_lr=self.learning_rate_init,
                    peak_lr=self.learning_rate,
                    warmup_steps=self.lr_warmup_steps)
            # print lr
            for param_group in self.optimizer.optimizer.param_groups:
                log.info('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # construct batches - allow re-shuffling of data
            log.info('--- construct train set ---')
            train_set.construct_batches(is_train=True)
            if dev_set is not None:
                log.info('--- construct dev set ---')
                dev_set.construct_batches(is_train=False)

            # print info
            steps_per_epoch = len(train_set.iter_loader)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.info(" ---------- Epoch: %d, Step: %d ----------" %
                     (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            log.info('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # loop over batches
            model.train(True)
            trainiter = iter(train_set.iter_loader)
            for idx in range(steps_per_epoch):

                # load batch items
                batch_items = next(trainiter)

                # update macro count
                step += 1
                step_elapsed += 1

                if self.lr_warmup_steps != 0:
                    self.optimizer.optimizer = self.lr_scheduler(
                        self.optimizer.optimizer,
                        step,
                        init_lr=self.learning_rate_init,
                        peak_lr=self.learning_rate,
                        warmup_steps=self.lr_warmup_steps)

                # Get loss
                losses = self._train_batch(model, batch_items, train_set, step,
                                           total_steps)
                loss_de = losses['nll_loss_de']
                loss_en = losses['nll_loss_en']

                print_loss_de_total += loss_de
                print_loss_en_total += loss_en

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_de_avg = print_loss_de_total / self.print_every
                    print_loss_de_total = 0
                    print_loss_en_avg = print_loss_en_total / self.print_every
                    print_loss_en_total = 0

                    log_msg = 'Progress: %d%%, Train nll_de: %.4f, nll_en: %.4f' % (
                        step / total_steps * 100, print_loss_de_avg,
                        print_loss_en_avg)
                    log.info(log_msg)
                    self.writer.add_scalar('train_loss_de',
                                           print_loss_de_avg,
                                           global_step=step)
                    self.writer.add_scalar('train_loss_en',
                                           print_loss_en_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        losses, metrics = self._evaluate_batches(
                            model, dev_set)

                        loss_de = losses['nll_loss_de']
                        accuracy_de = metrics['accuracy_de']
                        bleu_de = metrics['bleu_de']
                        loss_en = losses['nll_loss_en']
                        accuracy_en = metrics['accuracy_en']
                        bleu_en = metrics['bleu_en']

                        log_msg = 'Progress: %d%%, Dev DE loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_de, accuracy_de,
                            bleu_de)
                        log.info(log_msg)
                        log_msg = 'Progress: %d%%, Dev EN loss: %.4f, accuracy: %.4f, bleu: %.4f' % (
                            step / total_steps * 100, loss_en, accuracy_en,
                            bleu_en)
                        log.info(log_msg)

                        self.writer.add_scalar('dev_loss_de',
                                               loss_de,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_de',
                                               accuracy_de,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_de',
                                               bleu_de,
                                               global_step=step)
                        self.writer.add_scalar('dev_loss_en',
                                               loss_en,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc_en',
                                               accuracy_en,
                                               global_step=step)
                        self.writer.add_scalar('dev_bleu_en',
                                               bleu_en,
                                               global_step=step)

                        # save criterion - use EN stats: while BLEU is still
                        # near zero (< 0.1) an accuracy gain is enough,
                        # otherwise save on a new best BLEU
                        if ((prev_acc < accuracy_en) and
                            (bleu_en < 0.1)) or prev_bleu < bleu_en:

                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            log.info('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy_en
                            prev_bleu = bleu_en
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > self.max_count_no_improve:
                            # break after self.max_count_no_improve epochs
                            if self.max_count_num_rollback == 0:
                                break
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so its
                                # param_groups reference the restored model's
                                # parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > self.max_count_num_rollback:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                log.info(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so its
                                # param_groups reference the restored model's
                                # parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                log.info('reducing lr ...')
                                log.info('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr <= 0.125 * self.learning_rate:
                                log.info('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is not None:
                            ckpt.rm_old(self.expt_dir, keep_num=self.keep_num)
                        log.info('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))

                    sys.stdout.flush()
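
            # NOTE: the else below belongs to the batch for-loop: it runs
            # only when the loop finishes without break, and its continue
            # skips the outer-loop break that handles early stopping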

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    log.info('saving at {} ... '.format(saved_path))
                continue

            # break nested for loop
            break
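
The warmup calls above assume an lr_scheduler helper with the signature
lr_scheduler(optimizer, step, init_lr=..., peak_lr=..., warmup_steps=...).
Its body is not part of this listing; below is a minimal sketch, assuming
linear warmup from init_lr to peak_lr followed by inverse-square-root decay
(a common schedule, not necessarily the one used here):

    def lr_scheduler(optimizer, step, init_lr, peak_lr, warmup_steps):
        # linear warmup from init_lr to peak_lr over warmup_steps,
        # then inverse-sqrt decay from the peak (assumed schedule)
        if step < warmup_steps:
            lr = init_lr + (peak_lr - init_lr) * step / max(1, warmup_steps)
        else:
            lr = peak_lr * (warmup_steps / max(1, step)) ** 0.5
        # update the rate in place and return the optimizer, matching
        # how the call sites reassign self.optimizer.optimizer
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer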
Example #2
    def _train_epoches(self,
                       train_set,
                       model,
                       n_epochs,
                       start_epoch,
                       start_step,
                       dev_set=None):

        log = self.logger

        print_loss_total = 0  # Reset every print_every
        epoch_loss_total = 0  # Reset every epoch
        att_print_loss_total = 0  # Reset every print_every
        att_epoch_loss_total = 0  # Reset every epoch
        attcls_print_loss_total = 0  # Reset every print_every
        attcls_epoch_loss_total = 0  # Reset every epoch

        step = start_step
        step_elapsed = 0
        prev_acc = 0.0
        count_no_improve = 0
        count_num_rollback = 0
        ckpt = None

        # ******************** [loop over epochs] ********************
        for epoch in range(start_epoch, n_epochs + 1):

            for param_group in self.optimizer.optimizer.param_groups:
                print('epoch:{} lr: {}'.format(epoch, param_group['lr']))
                lr_curr = param_group['lr']

            # ----------construct batches-----------
            # allow re-shuffling of data
            if train_set.attkey_path is None:
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    dev_batches, vocab_size = dev_set.construct_batches(
                        is_train=False)
            else:
                print('--- construct train set ---')
                train_batches, vocab_size = train_set.construct_batches_with_ddfd_prob(
                    is_train=True)
                if dev_set is not None:
                    print('--- construct dev set ---')
                    assert dev_set.attkey_path is not None, \
                        'Dev set missing ddfd probabilities'
                    dev_batches, vocab_size = dev_set.construct_batches_with_ddfd_prob(
                        is_train=False)

            # --------print info for each epoch----------
            steps_per_epoch = len(train_batches)
            total_steps = steps_per_epoch * n_epochs
            log.info("steps_per_epoch {}".format(steps_per_epoch))
            log.info("total_steps {}".format(total_steps))

            log.debug(
                " ----------------- Epoch: %d, Step: %d -----------------" %
                (epoch, step))
            mem_kb, mem_mb, mem_gb = get_memory_alloc()
            mem_mb = round(mem_mb, 2)
            print('Memory used: {0:.2f} MB'.format(mem_mb))
            self.writer.add_scalar('Memory_MB', mem_mb, global_step=step)
            sys.stdout.flush()

            # ******************** [loop over batches] ********************
            model.train(True)
            for batch in train_batches:

                # update macro count
                step += 1
                step_elapsed += 1

                # load data
                src_ids = batch['src_word_ids']
                src_lengths = batch['src_sentence_lengths']
                tgt_ids = batch['tgt_word_ids']
                tgt_lengths = batch['tgt_sentence_lengths']

                src_probs = None
                src_labs = None
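                # optional per-token ddfd features; when present, each is
                # expanded with unsqueeze(2) to shape (batch, seq_len, 1)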
                if 'src_ddfd_probs' in batch and model.additional_key_size > 0:
                    src_probs = batch['src_ddfd_probs']
                    src_probs = _convert_to_tensor(src_probs,
                                                   self.use_gpu).unsqueeze(2)
                if 'src_ddfd_labs' in batch:
                    src_labs = batch['src_ddfd_labs']
                    src_labs = _convert_to_tensor(src_labs,
                                                  self.use_gpu).unsqueeze(2)

                # sanity check src-tgt pair
                if step == 1:
                    print('--- Check src tgt pair ---')
                    log_msgs = check_srctgt(src_ids, tgt_ids,
                                            train_set.src_id2word,
                                            train_set.tgt_id2word)
                    for log_msg in log_msgs:
                        sys.stdout.buffer.write(log_msg)

                # convert variable to tensor
                src_ids = _convert_to_tensor(src_ids, self.use_gpu)
                tgt_ids = _convert_to_tensor(tgt_ids, self.use_gpu)

                # Get loss
                loss, att_loss, attcls_loss = self._train_batch(
                    src_ids,
                    tgt_ids,
                    model,
                    step,
                    total_steps,
                    src_probs=src_probs,
                    src_labs=src_labs)

                print_loss_total += loss
                epoch_loss_total += loss
                att_print_loss_total += att_loss
                att_epoch_loss_total += att_loss
                attcls_print_loss_total += attcls_loss
                attcls_epoch_loss_total += attcls_loss

                if step % self.print_every == 0 and step_elapsed > self.print_every:
                    print_loss_avg = print_loss_total / self.print_every
                    att_print_loss_avg = att_print_loss_total / self.print_every
                    attcls_print_loss_avg = attcls_print_loss_total / self.print_every
                    print_loss_total = 0
                    att_print_loss_total = 0
                    attcls_print_loss_total = 0

                    log_msg = 'Progress: %d%%, Train nll: %.4f, att: %.4f, attcls: %.4f' % (
                        step / total_steps * 100, print_loss_avg,
                        att_print_loss_avg, attcls_print_loss_avg)
                    log.info(log_msg)
                    self.writer.add_scalar('train_loss',
                                           print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('att_train_loss',
                                           att_print_loss_avg,
                                           global_step=step)
                    self.writer.add_scalar('attcls_train_loss',
                                           attcls_print_loss_avg,
                                           global_step=step)

                # Checkpoint
                if step % self.checkpoint_every == 0 or step == total_steps:

                    # save criteria
                    if dev_set is not None:
                        dev_loss, accuracy, dev_attlosses = self._evaluate_batches(
                            model, dev_batches, dev_set)
                        dev_attloss = dev_attlosses['att_loss']
                        dev_attclsloss = dev_attlosses['attcls_loss']
                        log_msg = 'Progress: %d%%, Dev loss: %.4f, accuracy: %.4f, att: %.4f, attcls: %.4f' % (
                            step / total_steps * 100, dev_loss, accuracy,
                            dev_attloss, dev_attclsloss)
                        log.info(log_msg)
                        self.writer.add_scalar('dev_loss',
                                               dev_loss,
                                               global_step=step)
                        self.writer.add_scalar('dev_acc',
                                               accuracy,
                                               global_step=step)
                        self.writer.add_scalar('att_dev_loss',
                                               dev_attloss,
                                               global_step=step)
                        self.writer.add_scalar('attcls_dev_loss',
                                               dev_attclsloss,
                                               global_step=step)

                        # save
                        if prev_acc < accuracy:
                            # save best model
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)

                            saved_path = ckpt.save(self.expt_dir)
                            print('saving at {} ... '.format(saved_path))
                            # reset
                            prev_acc = accuracy
                            count_no_improve = 0
                            count_num_rollback = 0
                        else:
                            count_no_improve += 1

                        # roll back
                        if count_no_improve > MAX_COUNT_NO_IMPROVE:
                            # resuming
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so its
                                # param_groups reference the restored model's
                                # parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)
                                # note: epoch/step counters are not restored
                                # here (unlike the lr-decay rollback below)

                            # reset
                            count_no_improve = 0
                            count_num_rollback += 1

                        # update learning rate
                        if count_num_rollback > MAX_COUNT_NUM_ROLLBACK:

                            # roll back
                            latest_checkpoint_path = Checkpoint.get_latest_checkpoint(
                                self.expt_dir)
                            if latest_checkpoint_path is not None:
                                resume_checkpoint = Checkpoint.load(
                                    latest_checkpoint_path)
                                print(
                                    'epoch:{} step: {} - rolling back {} ...'.
                                    format(epoch, step,
                                           latest_checkpoint_path))
                                model = resume_checkpoint.model
                                self.optimizer = resume_checkpoint.optimizer
                                # workaround: re-create the optimizer so its
                                # param_groups reference the restored model's
                                # parameters
                                resume_optim = self.optimizer.optimizer
                                defaults = resume_optim.param_groups[0]
                                defaults.pop('params', None)
                                defaults.pop('initial_lr', None)
                                self.optimizer.optimizer = resume_optim.__class__(
                                    model.parameters(), **defaults)
                                start_epoch = resume_checkpoint.epoch
                                step = resume_checkpoint.step

                            # decrease lr
                            for param_group in self.optimizer.optimizer.param_groups:
                                param_group['lr'] *= 0.5
                                lr_curr = param_group['lr']
                                print('reducing lr ...')
                                print('step:{} - lr: {}'.format(
                                    step, param_group['lr']))

                            # check early stop
                            if lr_curr < 0.000125:
                                print('early stop ...')
                                break

                            # reset
                            count_no_improve = 0
                            count_num_rollback = 0

                        model.train(mode=True)
                        if ckpt is None:
                            ckpt = Checkpoint(model=model,
                                              optimizer=self.optimizer,
                                              epoch=epoch,
                                              step=step,
                                              input_vocab=train_set.vocab_src,
                                              output_vocab=train_set.vocab_tgt)
                            saved_path = ckpt.save(self.expt_dir)
                        ckpt.rm_old(self.expt_dir, keep_num=KEEP_NUM)
                        print('n_no_improve {}, num_rollback {}'.format(
                            count_no_improve, count_num_rollback))
                    sys.stdout.flush()

            else:
                if dev_set is None:
                    # save every epoch if no dev_set
                    ckpt = Checkpoint(model=model,
                                      optimizer=self.optimizer,
                                      epoch=epoch,
                                      step=step,
                                      input_vocab=train_set.vocab_src,
                                      output_vocab=train_set.vocab_tgt)
                    saved_path = ckpt.save_epoch(self.expt_dir, epoch)
                    print('saving at {} ... '.format(saved_path))

                # per-epoch training loss summary
                if step_elapsed == 0:
                    continue
                epoch_loss_avg = epoch_loss_total / min(steps_per_epoch,
                                                        step - start_step)
                epoch_loss_total = 0
                log_msg = "Finished epoch %d: Train %s: %.4f" % (
                    epoch, self.loss.name, epoch_loss_avg)
                log.info('\n')
                log.info(log_msg)
                continue

            # break nested for loop on early stop
            break
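
Example #2 also relies on a _convert_to_tensor helper that is not shown in
the listing. A plausible minimal version, assuming the batches arrive as
padded Python lists and that use_gpu simply moves the tensor to CUDA (exact
dtype handling is a guess):

    import torch

    def _convert_to_tensor(data, use_gpu):
        # wrap the padded id/probability lists in a tensor
        # (integer lists become int64, float lists become float32)
        tensor = torch.tensor(data)
        if use_gpu and torch.cuda.is_available():
            tensor = tensor.cuda()
        return tensor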