Example #1
    # -------------------------------------------------------------------------------------------
    # start training
    # -------------------------------------------------------------------------------------------
    # 8. Start training
    model.zero_grad()
    stats = {'timer': utils.Timer(), 'epoch': 0, 'min_eval_loss': float("inf")}

    for epoch in range(1, (args.max_train_epochs + 1)):
        stats['epoch'] = epoch

        # train
        train(args, train_data_loader, model, stats, tb_writer)

        # eval & test
        eval_loss = evaluate(args, valid_data_loader, model, stats, tb_writer)
        if eval_loss < stats['min_eval_loss']:
            stats['min_eval_loss'] = eval_loss
            logger.info(
                " *********************************************************************** "
            )
            logger.info('Update Min Eval_Loss = %.6f (epoch = %d)' %
                        (stats['min_eval_loss'], stats['epoch']))

        # Checkpoint
        if args.save_checkpoint:
            model.save_checkpoint(
                os.path.join(args.output_folder,
                             'epoch_{}.checkpoint'.format(epoch)),
                stats['epoch'])
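
Both examples lean on scaffolding that sits outside the snippet: a `utils.Timer` stored in the `stats` dict and a `save_checkpoint` method on the model. The sketch below shows one plausible shape for those pieces; the `Timer` class and `save_checkpoint` helper are illustrative assumptions, not the repository's actual code.

    import time

    import torch


    class Timer:
        """Wall-clock timer; a hypothetical stand-in for utils.Timer."""

        def __init__(self):
            self.start = time.time()

        def time(self):
            """Seconds elapsed since construction."""
            return time.time() - self.start


    def save_checkpoint(model, path, epoch):
        """Illustrative checkpoint helper: persists the weights plus the epoch counter."""
        torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, path)
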
Example #2
    # -------------------------------------------------------------------------------------------
    # start training
    # -------------------------------------------------------------------------------------------
    model.zero_grad()
    stats = {'timer': utils.Timer(), 'epoch': 0, main_metric_name: 0}
    for epoch in range(1, (args.max_train_epochs+1)):
        stats['epoch'] = epoch
        
        # train 
        train(args, train_data_loader, model, train_input_refactor, stats, tb_writer)

        # previous metric score
        prev_metric_score = stats[main_metric_name]
        
        # decode candidate phrases
        dev_candidate = candidate_decoder(args, dev_data_loader, dev_dataset, model, test_input_refactor, pred_arranger, 'dev')
        stats = evaluate_script(args, dev_candidate, stats, mode='dev', metric_name=main_metric_name)
            
        # new metric score
        new_metric_score = stats[main_metric_name]
            
        # save checkpoint when the new metric score beats the previous epoch's score
        if args.save_checkpoint and (new_metric_score > prev_metric_score) and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
            checkpoint_name = '{}.{}.{}.epoch_{}.checkpoint'.format(args.model_class, args.dataset_class, args.pretrain_model_type.split('-')[0], epoch)
            model.save_checkpoint(os.path.join(args.checkpoint_folder, checkpoint_name), stats['epoch'])
        
        # evaluation on the eval split (kp20k only)
        if args.dataset_class == 'kp20k':
            eval_candidate = candidate_decoder(args, eval_data_loader, eval_dataset, model, test_input_refactor, pred_arranger, 'eval')
            eval_stats = {'epoch': epoch, main_metric_name: 0}
            eval_stats = evaluate_script(args, eval_candidate, eval_stats, mode='eval', metric_name=main_metric_name)
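
Note that Example #2 appears to save a checkpoint whenever the new dev score beats the score from the immediately preceding epoch, not the best score seen so far (depending on how `evaluate_script` updates `stats`), so a dip followed by a partial recovery can still trigger a save. A common variant tracks the running best instead; below is a minimal, self-contained sketch of that pattern, where `keep_best_checkpoint` and its arguments are illustrative names, not part of the repository.

    def keep_best_checkpoint(scores, save_fn):
        """Call save_fn(epoch) only when the dev metric sets a new best so far.

        scores: per-epoch dev metric values, in epoch order.
        save_fn: callable that persists a checkpoint for the given epoch.
        """
        best = float('-inf')
        for epoch, score in enumerate(scores, start=1):
            if score > best:
                best = score
                save_fn(epoch)
        return best


    # Epochs 1, 2, and 4 improve on the running best; epoch 3's dip does not.
    saved = []
    keep_best_checkpoint([0.21, 0.25, 0.24, 0.27], saved.append)
    print(saved)  # -> [1, 2, 4]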