Code example #1

# Imports assumed from the surrounding project (module paths may differ):
import copy
import os
import time
import traceback

import torch
import torch.optim as optim
import tqdm

from src import utils                  # project utilities (data loading, checkpoints, ...)
from src.models.model import IRNet
from src.rule import semQL
def train(args):
    """
    :param args:
    :return:
    """
    grammar = semQL.Grammar()
    sql_data, table_data, val_sql_data, val_table_data = utils.load_dataset(
        args.dataset, use_small=args.toy)

    model = IRNet(args, grammar)
    if args.cuda: model.cuda()

    # resolve the optimizer class by name; getattr avoids eval()
    optimizer_cls = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_cls(model.parameters(), lr=args.lr)
    print('Enable Learning Rate Scheduler: ', args.lr_scheduler)
    if args.lr_scheduler:
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[21, 41], gamma=args.lr_scheduler_gammar)
    else:
        scheduler = None

    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

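    # Optionally warm-start from a pretrained checkpoint, keeping only the keys present in this model.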
    if args.load_model:
        print('load pretrained model from %s' % (args.load_model))
        pretrained_model = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        filtered_state = copy.deepcopy(pretrained_model)
        for k in pretrained_model.keys():
            if k not in model.state_dict().keys():
                del filtered_state[k]

        model.load_state_dict(filtered_state)

    model.word_emb = utils.load_word_emb(args.glove_embed_path)
    # begin train

    model_save_path = utils.init_log_checkpoint_path(args)
    utils.save_args(args, os.path.join(model_save_path, 'config.json'))
    best_dev_acc = 0.0

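    # Main training loop: train one epoch, score the dev set, and keep the best checkpoint.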
    try:
        with open(os.path.join(model_save_path, 'epoch.log'), 'w') as epoch_fd:
            for epoch in tqdm.tqdm(range(args.epoch)):
                if args.lr_scheduler:
                    scheduler.step()
                epoch_begin = time.time()
                loss = utils.epoch_train(
                    model,
                    optimizer,
                    args.batch_size,
                    sql_data,
                    table_data,
                    args,
                    loss_epoch_threshold=args.loss_epoch_threshold,
                    sketch_loss_coefficient=args.sketch_loss_coefficient)
                epoch_end = time.time()
                json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc(
                    model,
                    args.batch_size,
                    val_sql_data,
                    val_table_data,
                    beam_size=args.beam_size)
                # acc = utils.eval_acc(json_datas, val_sql_data)

                if acc > best_dev_acc:
                    utils.save_checkpoint(
                        model, os.path.join(model_save_path,
                                            'best_model.model'))
                    best_dev_acc = acc
                # format the filename before joining so the path itself is
                # never run through the %-operator
                utils.save_checkpoint(
                    model,
                    os.path.join(model_save_path,
                                 '%s_%s.model' % (epoch, acc)))

                log_str = 'Epoch: %d, Loss: %f, Sketch Acc: %f, Acc: %f, time: %f\n' % (
                    epoch + 1, loss, sketch_acc, acc, epoch_end - epoch_begin)
                tqdm.tqdm.write(log_str)
                epoch_fd.write(log_str)
                epoch_fd.flush()
    except Exception as e:
        # Save model
        utils.save_checkpoint(model,
                              os.path.join(model_save_path, 'end_model.model'))
        print(e)
        tb = traceback.format_exc()
        print(tb)
    else:
        utils.save_checkpoint(model,
                              os.path.join(model_save_path, 'end_model.model'))
        json_datas, sketch_acc, acc, counts, corrects = utils.epoch_acc(
            model,
            args.batch_size,
            val_sql_data,
            val_table_data,
            beam_size=args.beam_size)
        # acc = utils.eval_acc(json_datas, val_sql_data)

        print("Sketch Acc: %f, Acc: %f, Beam Acc: %f" % (
            sketch_acc,
            acc,
            acc,
        ))
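A minimal sketch of how this entry point might be driven. The flag names below mirror the attributes train() reads in the body above; the parser itself and all defaults are illustrative, not taken from the project.

import argparse

def build_args():
    # Hypothetical parser covering only the attributes train() consumes.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='./data')
    parser.add_argument('--toy', action='store_true')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--optimizer', type=str, default='Adam')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_scheduler', action='store_true')
    parser.add_argument('--lr_scheduler_gammar', type=float, default=0.5)
    parser.add_argument('--loss_epoch_threshold', type=int, default=50)
    parser.add_argument('--sketch_loss_coefficient', type=float, default=1.0)
    parser.add_argument('--load_model', type=str, default='')
    parser.add_argument('--glove_embed_path', type=str, default='./glove.42B.300d.txt')
    parser.add_argument('--epoch', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--beam_size', type=int, default=5)
    return parser.parse_args()

if __name__ == '__main__':
    train(build_args())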
Code example #2

# Imports assumed from the surrounding project; the utils and Logger
# module paths are placeholders and may differ in the actual repo:
import copy
import os
import time

import torch

import utils                    # project utilities (batching, checkpoints, ...)
from logger import Logger       # simple file logger with put()/close()
def train(args, model, optimizer, bert_optimizer, data):
    '''
    :param args:
    :param model:
    :param optimizer:
    :param bert_optimizer:
    :param data:
    :return:
    '''

    print('Loss epoch threshold: %d' % args.loss_epoch_threshold)
    print('Sketch loss coefficient: %f' % args.sketch_loss_coefficient)

    if args.load_model and not args.resume:
        print('load pretrained model from %s' % (args.load_model))
        pretrained_model = torch.load(
            args.load_model, map_location=lambda storage, loc: storage)
        # keep only the checkpoint entries that exist in the current model
        filtered_state = copy.deepcopy(pretrained_model)
        for k in pretrained_model.keys():
            if k not in model.state_dict().keys():
                del filtered_state[k]

        model.load_state_dict(filtered_state)

    # ==============data==============
    if args.interaction_level:
        batch_size = 1
        train_batchs, train_sample_batchs = data.get_interaction_batches(
            batch_size,
            sample_num=args.train_evaluation_size,
            use_small=args.toy)
        valid_batchs = data.get_all_interactions(data.val_sql_data,
                                                 data.val_table_data,
                                                 _type='test',
                                                 use_small=args.toy)

    else:
        batch_size = args.batch_size
        train_batchs = data.get_utterance_batches(batch_size)
        # avoid a NameError below when not training at the interaction level
        train_sample_batchs = []
        valid_batchs = data.get_all_utterances(data.val_sql_data)
    print('train batches: %d, sample batches: %d, valid batches: %d' %
          (len(train_batchs), len(train_sample_batchs), len(valid_batchs)))
    start_epoch = 1
    best_question_match = 0.0
    lr = args.initial_lr
    stage = 1

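    # Resume training: restore weights, epoch counter, learning rate, and stage from the latest checkpoint.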
    if args.resume:
        model_save_path = utils.init_log_checkpoint_path(args)
        current_w = torch.load(
            os.path.join(model_save_path, args.current_model_name))
        best_w = torch.load(os.path.join(model_save_path,
                                         args.best_model_name))
        best_question_match = best_w['question_match']
        start_epoch = current_w['epoch'] + 1
        lr = current_w['lr']
        utils.adjust_learning_rate(optimizer, lr)
        stage = current_w['stage']
        model.load_state_dict(current_w['state_dict'])
        # If the resume point is exactly a stage-transition epoch, advance the
        # stage, decay the lr, and reload the best weights.
        if start_epoch - 1 in args.stage_epoch:
            stage += 1
            lr /= args.lr_decay
            utils.adjust_learning_rate(optimizer, lr)
            model.load_state_dict(best_w['state_dict'])
        print("=> Loading resume model from epoch {} ...".format(start_epoch -
                                                                 1))

    # model.word_emb = utils.load_word_emb(args.glove_embed_path,use_small=args.use_small)
    # begin train

    model_save_path = utils.init_log_checkpoint_path(args)
    utils.save_args(args, os.path.join(model_save_path, 'config.json'))
    file_mode = 'a' if args.resume else 'w'
    log = Logger(os.path.join(args.save, args.logfile), file_mode)
    # log_pred_gt = Logger(os.path.join(args.save, args.log_pred_gt), file_mode)

    with open(os.path.join(model_save_path, 'epoch.log'), file_mode) as epoch_fd:

        for epoch in range(start_epoch, args.epoch + 1):
            epoch_begin = time.time()
            # model.set_dropout(args.dropout_amount)

            model.dropout_ratio = args.dropout_amount

            if args.interaction_level:
                loss = utils.epoch_train_with_interaction(
                    epoch, log, model, optimizer, bert_optimizer, train_batchs,
                    args)
            else:
                # Utterance-level training is not implemented here; `loss` is
                # consumed below, so fail fast rather than hit a NameError later.
                raise NotImplementedError(
                    'only interaction-level training is supported')

            model.dropout_ratio = 0.
            # model.set_dropout(0.)
            epoch_end = time.time()

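            # Quick accuracy check on a sample of training interactions, decoding from the model's own predicted queries.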
            s = time.time()
            sample_sketch_acc, sample_lf_acc, sample_interaction_lf_acc = utils.epoch_acc_with_interaction(
                epoch,
                model,
                train_sample_batchs,
                args,
                beam_size=args.beam_size,
                use_predicted_queries=True)
            log_str = '[Epoch: %d (sample predicted)], Sample ratio [%d]: %f, Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Train time: %f, Sample predict time: %f\n' % (
                epoch, len(train_sample_batchs),
                len(train_sample_batchs) / len(train_batchs),
                sample_sketch_acc, sample_lf_acc, sample_interaction_lf_acc,
                epoch_end - epoch_begin, time.time() - s)
            print(log_str)
            log.put(log_str)
            epoch_fd.write(log_str)
            epoch_fd.flush()

            # s = time.time()
            #
            # gold_sketch_acc,gold_lf_acc ,gold_interaction_lf_acc = utils.epoch_acc_with_interaction(epoch,model, valid_batchs,args,beam_size=args.beam_size)
            #
            # log_str = '[Epoch: %d(gold)], Loss: %f, Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Gold predict time: %f\n' % (
            #     epoch, loss, gold_sketch_acc, gold_lf_acc, gold_interaction_lf_acc,time.time()-s)
            # print(log_str)
            # log.put(log_str)
            # epoch_fd.write(log_str)
            # epoch_fd.flush()

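            # Full validation pass with predicted queries; the JSON dump is scored for question- and interaction-level match.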
            s = time.time()
            valid_jsonf, pred_sketch_acc, pred_lf_acc, pred_interaction_lf_acc = utils.epoch_acc_with_interaction_save_json(
                epoch,
                model,
                valid_batchs,
                args,
                beam_size=args.beam_size,
                use_predicted_queries=True)

            question_match, interaction_match = utils.semQL2SQL_question_and_interaction_match(
                valid_jsonf, args)

            log_str = '[Epoch: %d (predicted)], Loss: %f, lr: %.3e, Sketch Acc: %f, Acc: %f, Interaction Acc: %f, Question Match: %f, Interaction Match: %f, Predicted predict time: %f\n\n' % (
                epoch, loss, optimizer.param_groups[0]["lr"], pred_sketch_acc,
                pred_lf_acc, pred_interaction_lf_acc, question_match,
                interaction_match, time.time() - s)
            print(log_str)
            log.put(log_str)
            epoch_fd.write(log_str)
            epoch_fd.flush()

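            # Snapshot everything needed to resume: weights, epoch, metrics, lr, and stage.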
            state = {
                "state_dict": model.state_dict(),
                "epoch": epoch,
                "question_match": question_match,
                "interaction_match": interaction_match,
                "lr": lr,
                'stage': stage
            }

            current_w_name = os.path.join(
                model_save_path, '{}_{:.3f}.pth'.format(epoch, question_match))
            best_w_name = os.path.join(model_save_path, args.best_model_name)
            current_w_name2 = os.path.join(model_save_path,
                                           args.current_model_name)

            utils.save_ckpt(state, best_question_match < question_match,
                            current_w_name, best_w_name, current_w_name2)
            best_question_match = max(best_question_match, question_match)
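            # Stage transition: decay the learning rate and restart from the best checkpoint so far.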
            if epoch in args.stage_epoch:
                stage += 1
                lr /= args.lr_decay
                best_w_name = os.path.join(model_save_path,
                                           args.best_model_name)
                model.load_state_dict(torch.load(best_w_name)['state_dict'])
                print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
                utils.adjust_learning_rate(optimizer, lr)

    log.put("Finished training!")
    log.close()
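A rough sketch of how this entry point might be wired up. build_model and InteractionDataset are placeholders for whatever the surrounding project actually provides; the two-optimizer split (a smaller learning rate for BERT parameters) is an assumption based on the bert_optimizer argument, not taken from the source.

def main(args):
    data = InteractionDataset(args)    # placeholder: wrapper exposing the
                                       # get_*_batches methods used above
    model = build_model(args)          # placeholder: constructs the network
    # Separate optimizers so the BERT encoder can use a smaller learning rate.
    bert_params = [p for n, p in model.named_parameters() if 'bert' in n]
    other_params = [p for n, p in model.named_parameters() if 'bert' not in n]
    optimizer = torch.optim.Adam(other_params, lr=args.initial_lr)
    bert_optimizer = torch.optim.Adam(bert_params, lr=args.initial_lr * 0.1)
    train(args, model, optimizer, bert_optimizer, data)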