Example #1
def run_train_epoch(args, global_step, model, param_optimizer, train_dataloader,
                    eval_examples, eval_features, eval_dataloader,
                    optimizer, n_gpu, device, logger, log_path, save_path,
                    save_checkpoints_steps, start_save_steps, best_f1):
    running_loss, count = 0.0, 0
    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            batch = tuple(t.to(device) for t in batch)  # multi-GPU does the scattering itself
        input_ids, input_mask, segment_ids, span_starts, span_ends, labels, label_masks = batch
        loss = model('train', input_mask, input_ids=input_ids, token_type_ids=segment_ids,
                     span_starts=span_starts, span_ends=span_ends, labels=labels, label_masks=label_masks)
        loss = post_process_loss(args, n_gpu, loss)
        loss.backward()
        running_loss += loss.item()

        if (step + 1) % args.gradient_accumulation_steps == 0:
            if args.fp16 or args.optimize_on_cpu:
                if args.fp16 and args.loss_scale != 1.0:
                    # scale down gradients for fp16 training
                    for param in model.parameters():
                        param.grad.data = param.grad.data / args.loss_scale
                is_nan = set_optimizer_params_grad(param_optimizer, model.named_parameters(), test_nan=True)
                if is_nan:
                    logger.info("FP16 TRAINING: Nan in gradients, reducing loss scaling")
                    args.loss_scale = args.loss_scale / 2
                    model.zero_grad()
                    continue
                optimizer.step()
                copy_optimizer_params_to_model(model.named_parameters(), param_optimizer)
            else:
                optimizer.step()
            model.zero_grad()
            global_step += 1
            count += 1

            if global_step % save_checkpoints_steps == 0 and count != 0:
                logger.info("step: {}, loss: {:.4f}".format(global_step, running_loss / count))

            if global_step % save_checkpoints_steps == 0 and global_step > start_save_steps and count != 0:  # eval & save model
                logger.info("***** Running evaluation *****")
                model.eval()
                metrics = evaluate(args, model, device, eval_examples, eval_features, eval_dataloader, logger)
                with open(log_path, "a") as f:
                    print("step: {}, loss: {:.4f}, P: {:.4f}, R: {:.4f}, F1: {:.4f} "
                          "(common: {}, retrieved: {}, relevant: {})"
                          .format(global_step, running_loss / count, metrics['p'], metrics['r'],
                                  metrics['f1'], metrics['common'], metrics['retrieved'], metrics['relevant']), file=f)
                    print(" ", file=f)
                running_loss, count = 0.0, 0
                model.train()
                if metrics['f1'] > best_f1:
                    best_f1 = metrics['f1']
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': global_step
                    }, save_path)
                if args.debug:
                    break
    return global_step, model, best_f1
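
Both examples lean on two helpers, set_optimizer_params_grad and copy_optimizer_params_to_model, that are not shown on this page. The sketch below reconstructs them under the assumption that param_optimizer holds (name, parameter) pairs of fp32 master copies of the model weights, the pattern used by early BERT fine-tuning scripts for --fp16/--optimize_on_cpu; only the names and call signatures come from the snippets above, the bodies are an assumption.

import torch


def set_optimizer_params_grad(param_optimizer, named_params_model, test_nan=False):
    # Sketch (assumed): copy gradients from the (possibly fp16) model parameters
    # onto the fp32 master copies held by the optimizer; optionally flag NaNs.
    is_nan = False
    for (name_opti, param_opti), (name_model, param_model) in zip(param_optimizer, named_params_model):
        if name_opti != name_model:
            raise ValueError("parameter lists out of sync: {} vs {}".format(name_opti, name_model))
        if param_model.grad is None:
            param_opti.grad = None
            continue
        if test_nan and torch.isnan(param_model.grad).any():
            is_nan = True
        if param_opti.grad is None:
            param_opti.grad = torch.zeros_like(param_opti.data)
        param_opti.grad.data.copy_(param_model.grad.data)
    return is_nan


def copy_optimizer_params_to_model(named_params_model, param_optimizer):
    # Sketch (assumed): push the updated fp32 master weights back into the model
    # after optimizer.step() has run on the master copies.
    for (name_model, param_model), (name_opti, param_opti) in zip(named_params_model, param_optimizer):
        if name_model != name_opti:
            raise ValueError("parameter lists out of sync: {} vs {}".format(name_model, name_opti))
        param_model.data.copy_(param_opti.data)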
Example #2
def run_train_epoch(args, global_step, model, param_optimizer, train_examples,
                    train_features, train_dataloader, eval_examples,
                    eval_features, eval_dataloader, optimizer, n_gpu, device,
                    logger, log_path, save_path, save_checkpoints_steps,
                    start_save_steps, best_f1):
    running_loss, running_te_loss, running_tc_loss, count = 0.0, 0.0, 0.0, 0
    for step, batch in enumerate(train_dataloader):
        if n_gpu == 1:
            batch = tuple(t.to(device)
                          for t in batch)  # multi-GPU does the scattering itself
        input_ids, input_mask, segment_ids, start_positions, end_positions, aspect_num, example_indices = batch
        batch_start_logits, batch_end_logits, batch_aspect_num, _ = model(
            'extract_inference',
            input_mask,
            input_ids=input_ids,
            token_type_ids=segment_ids,
            logger=logger)

        batch_features, batch_results = [], []
        for j, example_index in enumerate(example_indices):
            start_logits = batch_start_logits[j].detach().cpu().tolist()
            end_logits = batch_end_logits[j].detach().cpu().tolist()
            aspect_num_logits = batch_aspect_num[j].detach().cpu().tolist()
            train_feature = train_features[example_index.item()]
            unique_id = int(train_feature.unique_id)
            batch_features.append(train_feature)
            batch_results.append(
                RawSpanResult(unique_id=unique_id,
                              start_logits=start_logits,
                              end_logits=end_logits,
                              aspect_num_logits=aspect_num_logits))

        span_starts, span_ends, span_aspect_num, labels, label_masks = span_annotate_candidates(
            train_examples, batch_features, batch_results, args.filter_type,
            True, args.use_heuristics, args.use_nms, args.n_best_size,
            args.max_answer_length, args.do_lower_case, args.verbose_logging,
            logger)

        span_starts = torch.tensor(span_starts, dtype=torch.long)
        span_ends = torch.tensor(span_ends, dtype=torch.long)
        span_aspect_num = torch.tensor(span_aspect_num, dtype=torch.long)
        labels = torch.tensor(labels, dtype=torch.long)
        label_masks = torch.tensor(label_masks, dtype=torch.long)
        span_starts = span_starts.to(device)
        span_ends = span_ends.to(device)

        labels = labels.to(device)
        label_masks = label_masks.to(device)

        te_loss, tc_loss = model('train',
                                 input_mask,
                                 input_ids=input_ids,
                                 token_type_ids=segment_ids,
                                 start_positions=start_positions,
                                 end_positions=end_positions,
                                 aspect_num=aspect_num,
                                 span_starts=span_starts,
                                 span_ends=span_ends,
                                 span_aspect_num=span_aspect_num,
                                 polarity_labels=labels,
                                 label_masks=label_masks,
                                 logger=logger)
        te_loss = post_process_loss(args, n_gpu, te_loss)
        tc_loss = post_process_loss(args, n_gpu, tc_loss)
        loss = post_process_loss(args, n_gpu, te_loss + tc_loss)
        loss.backward()
        running_te_loss += te_loss.item()
        running_tc_loss += tc_loss.item()
        running_loss += loss.item()

        if (step + 1) % args.gradient_accumulation_steps == 0:
            if args.fp16 or args.optimize_on_cpu:
                if args.fp16 and args.loss_scale != 1.0:
                    # scale down gradients for fp16 training
                    for param in model.parameters():
                        param.grad.data = param.grad.data / args.loss_scale
                is_nan = set_optimizer_params_grad(param_optimizer,
                                                   model.named_parameters(),
                                                   test_nan=True)
                if is_nan:
                    logger.info(
                        "FP16 TRAINING: Nan in gradients, reducing loss scaling"
                    )
                    args.loss_scale = args.loss_scale / 2
                    model.zero_grad()
                    continue
                optimizer.step()
                copy_optimizer_params_to_model(model.named_parameters(),
                                               param_optimizer)
            else:
                optimizer.step()
            model.zero_grad()
            global_step += 1
            count += 1

            if global_step % save_checkpoints_steps == 0 and count != 0:
                logger.info(
                    "step: {}, loss: {:.4f}, ae loss: {:.4f}, ac loss: {:.4f}".
                    format(global_step, running_loss / count,
                           running_te_loss / count, running_tc_loss / count))

            if global_step % save_checkpoints_steps == 0 and global_step > start_save_steps and count != 0:  # eval & save model
                logger.info("***** Running evaluation *****")
                model.eval()
                metrics = evaluate(args, model, device, eval_examples,
                                   eval_features, eval_dataloader, logger)
                with open(log_path, "a") as f:
                    print(
                        "step: {}, loss: {:.4f}, ae loss: {:.4f}, ac loss: {:.4f}, "
                        "P: {:.4f}, R: {:.4f}, F1: {:.4f} "
                        "(common: {}, retrieved: {}, relevant: {})".format(
                            global_step, running_loss / count,
                            running_te_loss / count, running_tc_loss / count,
                            metrics['p'], metrics['r'], metrics['f1'],
                            metrics['common'], metrics['retrieved'],
                            metrics['relevant']),
                        file=f)
                    print(" ", file=f)
                running_loss, running_te_loss, running_tc_loss, count = 0.0, 0.0, 0.0, 0
                model.train()
                if metrics['f1'] > best_f1:
                    best_f1 = metrics['f1']
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict(),
                            'step': global_step
                        }, save_path)
                if args.debug:
                    break
    return global_step, model, best_f1
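
Similarly, post_process_loss is called in both examples but not defined on this page. Below is a minimal, hedged sketch of what such a helper typically does, consistent with the fp16 gradient unscaling and gradient accumulation seen in the loops above; treat it as an assumption rather than the repository's actual implementation.

def post_process_loss(args, n_gpu, loss):
    # Sketch (assumed): reduce a raw loss tensor before calling backward().
    if n_gpu > 1:
        loss = loss.mean()  # DataParallel returns one loss per GPU
    if args.fp16 and args.loss_scale != 1.0:
        # Scale the loss up; the training loop divides the gradients by
        # args.loss_scale again before the optimizer step.
        loss = loss * args.loss_scale
    if args.gradient_accumulation_steps > 1:
        loss = loss / args.gradient_accumulation_steps
    return loss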