Beispiel #1
0
    def run(self, n_iters, model_path=None, max_hours=11.0, eval_interval=20000):
        """Train under a wall-clock budget, keeping the checkpoint with the best eval loss.

        The loop is bounded by elapsed wall-clock time (``max_hours``), not by
        ``n_iters``: ``n_iters`` is accepted but currently unused (an
        ``iter < n_iters`` bound was commented out in favour of the time budget).

        Every 100 steps the summary writer is flushed and the step, seconds
        spent on the interval, loss and coverage loss are printed. Every
        ``eval_interval`` steps the model is saved under the name
        ``'model_temp'``, evaluated with
        ``Evaluate(os.path.join(self.model_dir, 'model_temp')).run()``, and
        additionally saved as a regular checkpoint when that eval loss is lower
        than the best seen so far in this call.

        Args:
            n_iters: unused; kept for interface compatibility.
            model_path: optional checkpoint forwarded to ``self.setup_train``.
            max_hours: wall-clock training budget in hours (default 11.0).
            eval_interval: number of steps between evaluations (default 20000).
        """
        step, running_avg_loss = self.setup_train(model_path)
        # Two timers: ``train_start`` bounds the whole run and is never reset;
        # ``interval_start`` only measures each print interval. Sharing one
        # timer for both meant the budget was re-armed every 100 steps.
        train_start = time.time()
        interval_start = train_start
        interval = 100
        prev_eval_loss = float("inf")
        while (time.time() - train_start) / 3600 <= max_hours:
            batch = self.batcher.next_batch()
            loss, cove_loss = self.train_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, step)
            step += 1

            if step % interval == 0:
                self.summary_writer.flush()
                print('step: %d, second: %.2f , loss: %f, cover_loss: %f' %
                      (step, time.time() - interval_start, loss, cove_loss))
                interval_start = time.time()
            if step % eval_interval == 0:
                self.save_model(running_avg_loss, step, 'model_temp')
                eval_loss = Evaluate(os.path.join(self.model_dir,
                                                  'model_temp')).run()
                if eval_loss < prev_eval_loss:
                    print(
                        f"eval loss for iteration: {step} is {eval_loss}, previous best eval loss = {prev_eval_loss}, saving checkpoint..."
                    )
                    prev_eval_loss = eval_loss
                    self.save_model(running_avg_loss, step)
                else:
                    print(
                        f"eval loss for iteration: {step} is {eval_loss}, previous best eval loss = {prev_eval_loss}, no improvement, skipping..."
                    )
    def train_iters(self, n_iters, args):
        """Train for ``n_iters`` batches, evaluating periodically and saving the best model.

        Restores ``(start_iter, running_avg_loss, best_val_loss)`` from
        ``self.setup_train(args)`` and logs ``args`` and ``config``.

        Each loop iteration trains one batch:

        * A ``None`` loss skips the rest of that iteration.
        * A NaN loss logs a message and terminates the process with exit status 1.
        * Otherwise the running-average loss is updated.

        Progress is printed/logged every 200 completed steps. Every
        ``config.eval_interval`` steps ``self.run_eval`` is called, and the
        model is saved when the validation loss improves on the best so far.

        Args:
            n_iters: number of batches to draw in this call.
            args: run options forwarded to setup, training and evaluation.
        """
        import sys

        start_iter, running_avg_loss, best_val_loss = self.setup_train(args)
        logger = self.setup_logging()
        logger.debug(str(args))
        logger.debug(str(config))

        print_interval = 200
        start = time.time()

        for it in tqdm(range(n_iters), dynamic_ncols=True):
            step = start_iter + it
            # run_eval() puts the model in eval mode, so re-enable training mode
            # before every step.
            self.model.module.train()
            batch = self.train_batcher.next_batch()
            loss = self.train_one_batch(batch, args)
            if loss is None:
                # Must be checked before math.isnan(), which raises TypeError on
                # None; there is nothing to average, log or evaluate for this batch.
                continue
            if math.isnan(loss):
                msg = "Loss has reached NAN. Exiting"
                logger.debug(msg)
                print(msg)
                # Non-zero status so that a diverged run is not reported as success
                # (bare exit() exits with status 0).
                sys.exit(1)
            running_avg_loss = calc_running_avg_loss(
                loss, running_avg_loss, step)
            # Logging/eval boundaries are expressed in completed steps.
            step += 1

            if step % print_interval == 0:
                msg = 'steps %d, seconds for %d batch: %.2f , loss: %f' % (
                    step, print_interval, time.time() - start, loss)
                print(msg)
                logger.debug(msg)
                start = time.time()
            if step % config.eval_interval == 0:
                print("Starting Eval")
                val_loss = self.run_eval(logger, args)
                if best_val_loss is None or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    self.save_model(running_avg_loss, step, logger,
                                    best_val_loss)
                    print("Saving best model")
                    logger.debug("Saving best model")
Beispiel #3
0
    def run(self):
        """Evaluate every batch produced by ``self.batcher`` and return the running-average loss.

        Batches are pulled until ``next_batch()`` returns ``None``. Every 100
        batches the summary writer is flushed and the step count, seconds since
        the previous report and current running-average loss are printed.

        Returns:
            The running-average loss after the last batch (0 if there were none).
        """
        tick = time.time()
        avg_loss, step = 0, 0
        report_every = 100
        while True:
            batch = self.batcher.next_batch()
            if batch is None:
                break
            batch_loss = self.eval_one_batch(batch)
            avg_loss = calc_running_avg_loss(batch_loss, avg_loss,
                                             self.summary_writer, step)
            step += 1
            if step % report_every:
                continue
            self.summary_writer.flush()
            elapsed = time.time() - tick
            print('step: %d, second: %.2f , loss: %f' %
                  (step, elapsed, avg_loss))
            tick = time.time()
        return avg_loss
Beispiel #4
0
    def run(self, n_iters, model_path=None):
        """Train until the global step reaches ``n_iters``.

        Training state (starting step and running-average loss) comes from
        ``self.setup_train(model_path)``. Every 100 steps the summary writer is
        flushed and the step, seconds spent on the interval, loss and coverage
        loss are printed; every 5000 steps ``self.save_model`` is called.

        Args:
            n_iters: global step at which training stops.
            model_path: optional checkpoint forwarded to ``self.setup_train``.
        """
        step, avg_loss = self.setup_train(model_path)
        tick = time.time()
        log_every, save_every = 100, 5000

        while step < n_iters:
            loss, cove_loss = self.train_one_batch(self.batcher.next_batch())
            avg_loss = calc_running_avg_loss(loss, avg_loss,
                                             self.summary_writer, step)
            step += 1

            if not step % log_every:
                self.summary_writer.flush()
                elapsed = time.time() - tick
                print('step: %d, second: %.2f , loss: %f, cover_loss: %f' %
                      (step, elapsed, loss, cove_loss))
                tick = time.time()
            if not step % save_every:
                self.save_model(avg_loss, step)
Beispiel #5
0
    def run_eval(self):
        """Evaluate every batch produced by ``self.batcher`` and return the running-average loss.

        Batches are pulled until ``next_batch()`` returns ``None``. Every 1000
        batches the step count, interval length, seconds spent on the interval
        and current running-average loss are printed.

        Returns:
            The running-average loss after the last batch (0 if there were none).
        """
        print_interval = 1000
        running_avg_loss, step = 0, 0
        start = time.time()
        batch = self.batcher.next_batch()
        while batch is not None:
            loss = self.eval_one_batch(batch)

            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss,
                                                     self.summary_writer, step)
            step += 1

            if step % print_interval == 0:
                print('steps %d, seconds for %d batch: %.2f , loss: %f' %
                      (step, print_interval, time.time() - start,
                       running_avg_loss))
                start = time.time()
            batch = self.batcher.next_batch()
        # Hand the aggregate back to the caller instead of discarding it.
        return running_avg_loss
    def run_eval(self, logger, args):
        """Run one full pass over the eval set and report per-objective metrics.

        Switches ``self.model.module`` to eval mode, re-arms
        ``self.eval_batcher`` and consumes batches until ``next_batch()``
        returns ``None``. For every batch whose loss is not ``None``:

        * The total loss is folded into a running average with
          ``calc_running_avg_loss``.
        * Each objective enabled on ``args`` (its ``use_*_loss`` flag) has its
          own loss folded into a running average the same way.
        * Some objectives also accumulate correct/total counts.
        * The all-head / all-child objectives additionally collect predictions
          and labels for a ``classification_report``.

        Every report line is printed and sent to ``logger.debug``.

        Args:
            logger: receives each report line via ``logger.debug``.
            args: namespace carrying the ``use_*_loss`` flags; also forwarded to
                ``self.get_loss``.

        Returns:
            The running-average total eval loss (0 if no batch produced a loss).
        """
        # One row per optional objective:
        # (args flag, loss key, count keys to accumulate, eval_data keys to collect).
        objectives = (
            ('use_summ_loss', 'summ_loss', (), ()),
            ('use_sent_single_head_loss', 'sent_single_head_loss',
             ('sent_single_heads_num_correct', 'sent_single_heads_num'), ()),
            ('use_sent_all_head_loss', 'sent_all_head_loss',
             ('sent_all_heads_num_correct', 'sent_all_heads_num',
              'sent_all_heads_num_correct_1', 'sent_all_heads_num_1',
              'sent_all_heads_num_correct_0', 'sent_all_heads_num_0'),
             ('sent_all_heads_pred', 'sent_all_heads_true')),
            ('use_sent_all_child_loss', 'sent_all_child_loss',
             ('sent_all_child_num_correct', 'sent_all_child_num',
              'sent_all_child_num_correct_1', 'sent_all_child_num_1',
              'sent_all_child_num_correct_0', 'sent_all_child_num_0'),
             ('sent_all_child_pred', 'sent_all_child_true')),
            ('use_token_contsel_loss', 'token_contsel_loss',
             ('token_consel_num_correct', 'token_consel_num'), ()),
            ('use_sent_imp_loss', 'sent_imp_loss', (), ()),
            ('use_doc_imp_loss', 'doc_imp_loss', (), ()),
        )
        running_avg_loss, step = 0, 0
        run_avg_losses = {loss_key: 0 for _, loss_key, _, _ in objectives}
        counts = {key: 0 for _, _, count_keys, _ in objectives
                  for key in count_keys}
        eval_res = {key: [] for _, _, _, data_keys in objectives
                    for key in data_keys}
        enabled = [row for row in objectives if getattr(args, row[0])]

        def emit(msg):
            # Every report line goes to both stdout and the debug log.
            print(msg)
            logger.debug(msg)

        def accuracy(correct_key, total_key):
            # NaN instead of ZeroDivisionError when no sample of this kind was seen.
            total = counts[total_key]
            return counts[correct_key] / float(total) if total else float('nan')

        def emit_report(pred_key, true_key):
            # np.concatenate raises on an empty list, so skip the report when
            # no predictions were collected (e.g. an empty eval set).
            if eval_res[pred_key]:
                y_pred = np.concatenate(eval_res[pred_key])
                y_true = np.concatenate(eval_res[true_key])
                emit(classification_report(y_true, y_pred, labels=[0, 1]))

        self.model.module.eval()
        self.eval_batcher._finished_reading = False
        self.eval_batcher.setup_queues()
        # NOTE(review): no torch.no_grad() here — confirm get_loss(mode='eval')
        # disables autograd itself.
        batch = self.eval_batcher.next_batch()
        while batch is not None:
            loss, sample_ind_losses, sample_counts, eval_data = self.get_loss(
                batch, args, mode='eval')
            # Test for None before calling .item(); calling .item() first
            # would raise before the guard could ever take effect.
            if loss is not None:
                running_avg_loss = calc_running_avg_loss(
                    loss.item(), running_avg_loss, step)
                for _, loss_key, count_keys, data_keys in enabled:
                    run_avg_losses[loss_key] = calc_running_avg_loss(
                        sample_ind_losses[loss_key],
                        run_avg_losses[loss_key], step)
                    for key in count_keys:
                        counts[key] += sample_counts[key]
                    for key in data_keys:
                        eval_res[key].append(eval_data[key])
                step += 1
            batch = self.eval_batcher.next_batch()

        emit('Eval: loss: %f' % running_avg_loss)

        if args.use_summ_loss:
            emit('Summ Eval: loss: %f' % run_avg_losses['summ_loss'])
        if args.use_sent_single_head_loss:
            emit('Single Sent Head Eval: loss: %f' %
                 run_avg_losses['sent_single_head_loss'])
            emit('Average Sent Single Head Accuracy: %f' %
                 accuracy('sent_single_heads_num_correct',
                          'sent_single_heads_num'))
        if args.use_sent_all_head_loss:
            emit('All Sent Head Eval: loss: %f' %
                 run_avg_losses['sent_all_head_loss'])
            emit('Average Sent All Head Accuracy: %f' %
                 accuracy('sent_all_heads_num_correct', 'sent_all_heads_num'))
            emit_report('sent_all_heads_pred', 'sent_all_heads_true')
        if args.use_sent_all_child_loss:
            emit('All Sent Child Eval: loss: %f' %
                 run_avg_losses['sent_all_child_loss'])
            emit('Average Sent All Child Accuracy: %f' %
                 accuracy('sent_all_child_num_correct', 'sent_all_child_num'))
            emit_report('sent_all_child_pred', 'sent_all_child_true')
        if args.use_token_contsel_loss:
            emit('Token Contsel Eval: loss: %f' %
                 run_avg_losses['token_contsel_loss'])
            emit('Average token content sel Accuracy: %f' %
                 accuracy('token_consel_num_correct', 'token_consel_num'))
        if args.use_sent_imp_loss:
            emit('Sent Imp Eval: loss: %f' % run_avg_losses['sent_imp_loss'])
        if args.use_doc_imp_loss:
            emit('Doc Imp Eval: loss: %f' % run_avg_losses['doc_imp_loss'])

        return running_avg_loss