# Example no. 1
# 0
def main():
    """Train and evaluate a QuAC QA model (ELMo/BERT variant).

    Loads train/dev data, builds (or resumes) a ``QAModel``, then runs the
    train / evaluate / checkpoint loop, tracking the best dev F1.

    Relies on module-level globals defined elsewhere in this file:
    ``args``, ``log``, ``model_dir`` and the helpers ``load_train_data``,
    ``load_dev_data``, ``QAModel``, ``BatchGen_QuAC``, ``lr_decay``,
    ``find_best_score_and_thresh``.
    """
    log.info('[program starts.]')
    log.info('seed: {}'.format(args.seed))
    log.info(str(vars(args)))
    opt = vars(args)  # changing opt will change args
    train, train_embedding, opt = load_train_data(opt)
    dev, dev_embedding, dev_answer = load_dev_data(opt)
    # Extra features per explicit dialog-context turn:
    # 3 for the dialog act (if enabled) + 2 for the previous answer markers.
    opt['num_features'] += args.explicit_dialog_ctx * (
        args.use_dialog_act * 3 + 2)  # dialog_act + previous answer
    if not opt['use_elmo']:  # idiomatic truthiness test instead of `== False`
        opt['elmo_batch_size'] = 0
    log.info('[Data loaded.]')

    if args.resume:
        log.info('[loading previous model...]')
        # Remap checkpoint tensors onto the device we are actually using.
        if args.cuda:
            checkpoint = torch.load(args.resume,
                                    map_location={'cpu': 'cuda:0'})
        else:
            checkpoint = torch.load(args.resume,
                                    map_location={'cuda:0': 'cpu'})
        if args.resume_options:
            opt = checkpoint['config']
        state_dict = checkpoint['state_dict']
        model = QAModel(opt, train_embedding, state_dict)
        epoch_0 = checkpoint['epoch'] + 1
        # Advance the RNG once per completed epoch so the shuffling order
        # stays in sync with the interrupted run (same seed assumed).
        for _ in range(checkpoint['epoch']):
            random.shuffle(list(range(len(train))))  # synchronize random seed
        if args.reduce_lr:
            lr_decay(model.optimizer, lr_decay=args.reduce_lr)
    else:
        model = QAModel(opt, train_embedding)
        epoch_0 = 1

    if args.pretrain:
        # Warm-start the network weights from a pretrained checkpoint.
        pretrain_model = torch.load(args.pretrain)
        state_dict = pretrain_model['state_dict']['network']
        model.get_pretrain(state_dict)

    model.setup_eval_embed(dev_embedding)
    log.info("[dev] Total number of params: {}".format(model.total_param))

    if args.cuda:
        model.cuda()

    if args.resume:
        # Re-score the dev set so best_val_score reflects the resumed model
        # rather than starting from 0.0.
        batches = BatchGen_QuAC(dev,
                                batch_size=args.batch_size,
                                evaluation=True,
                                gpu=args.cuda,
                                dialog_ctx=args.explicit_dialog_ctx,
                                use_dialog_act=args.use_dialog_act,
                                use_bert=args.use_bert)
        predictions, no_ans_scores = [], []
        for batch in batches:
            phrases, noans = model.predict(batch)
            predictions.extend(phrases)
            no_ans_scores.extend(noans)
        f1, na, thresh = find_best_score_and_thresh(predictions, dev_answer,
                                                    no_ans_scores)
        log.info("[dev F1: {} NA: {} TH: {}]".format(f1, na, thresh))
        best_val_score, best_na, best_thresh = f1, na, thresh
    else:
        best_val_score, best_na, best_thresh = 0.0, 0.0, 0.0

    # Gradient accumulation: only step the optimizer every N batches when
    # fine-tuning BERT (otherwise step every batch).
    aggregate_grad_steps = 1
    if opt['use_bert']:
        aggregate_grad_steps = opt['aggregate_grad_steps']

    # BUG FIX: initialize the per-epoch metrics so the save/report code below
    # does not raise NameError on epochs where evaluation is skipped
    # (i.e. when args.eval_per_epoch > 1).
    f1, na, thresh = best_val_score, best_na, best_thresh

    for epoch in range(epoch_0, epoch_0 + args.epoches):

        log.warning('Epoch {}'.format(epoch))
        # train
        batches = BatchGen_QuAC(train,
                                batch_size=args.batch_size,
                                gpu=args.cuda,
                                dialog_ctx=args.explicit_dialog_ctx,
                                use_dialog_act=args.use_dialog_act,
                                precompute_elmo=args.elmo_batch_size //
                                args.batch_size,
                                use_bert=args.use_bert)
        start = datetime.now()

        total_batches = len(batches)
        model.optimizer.zero_grad()
        if opt['finetune_bert']:
            model.bertadam.zero_grad()

        for i, batch in enumerate(batches):
            model.update(batch)  # accumulates gradients; returns batch loss
            # Apply accumulated gradients every aggregate_grad_steps batches,
            # and always on the final batch of the epoch.
            if (i + 1) % aggregate_grad_steps == 0 or total_batches == (i + 1):
                model.take_step()
            if i % args.log_per_updates == 0:
                log.info(
                    'updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format(
                        model.updates, model.train_loss.avg,
                        str((datetime.now() - start) / (i + 1) *
                            (len(batches) - i - 1)).split('.')[0]))
        # eval
        if epoch % args.eval_per_epoch == 0:
            batches = BatchGen_QuAC(dev,
                                    batch_size=args.batch_size,
                                    evaluation=True,
                                    gpu=args.cuda,
                                    dialog_ctx=args.explicit_dialog_ctx,
                                    use_dialog_act=args.use_dialog_act,
                                    precompute_elmo=args.elmo_batch_size //
                                    args.batch_size,
                                    use_bert=args.use_bert)
            predictions, no_ans_scores = [], []
            for batch in batches:
                phrases, noans = model.predict(batch)
                predictions.extend(phrases)
                no_ans_scores.extend(noans)
            f1, na, thresh = find_best_score_and_thresh(
                predictions, dev_answer, no_ans_scores)

        # save
        if args.save_best_only:
            if f1 > best_val_score:
                best_val_score, best_na, best_thresh = f1, na, thresh
                model_file = os.path.join(model_dir, 'best_model.pt')
                model.save(model_file, epoch)
                log.info('[new best model saved.]')
        else:
            model_file = os.path.join(model_dir,
                                      'checkpoint_epoch_{}.pt'.format(epoch))
            model.save(model_file, epoch)
            if f1 > best_val_score:
                best_val_score, best_na, best_thresh = f1, na, thresh
                # BUG FIX: model_file already contains model_dir; joining it
                # with model_dir again produced a wrong (doubled) source path.
                copyfile(model_file,
                         os.path.join(model_dir, 'best_model.pt'))
                log.info('[new best model saved.]')

        log.warning(
            "Epoch {} - dev F1: {:.3f} NA: {:.3f} TH: {:.3f} (best F1: {:.3f} NA: {:.3f} TH: {:.3f})"
            .format(epoch, f1, na, thresh, best_val_score, best_na,
                    best_thresh))
# Example no. 2
# 0
def main():
    """Train and evaluate a QuAC QA model (non-BERT variant).

    Loads train/dev data, builds (or resumes) a ``QAModel``, then runs the
    train / evaluate / checkpoint loop, tracking the best dev F1.

    Relies on module-level globals defined elsewhere in this file:
    ``args``, ``log``, ``model_dir`` and the helpers ``load_train_data``,
    ``load_dev_data``, ``QAModel``, ``BatchGen_QuAC``, ``lr_decay``,
    ``find_best_score_and_thresh``.
    """
    log.info('[program starts.]')
    opt = vars(args)  # changing opt will change args

    train, opt = load_train_data(opt)
    dev, dev_answer = load_dev_data(opt)

    # Extra features per explicit dialog-context turn:
    # 3 for the dialog act (if enabled) + 2 for the previous answer markers.
    opt['num_features'] += args.explicit_dialog_ctx * (args.use_dialog_act * 3 + 2)  # dialog_act + previous answer
    if not opt['use_elmo']:  # idiomatic truthiness test instead of `== False`
        opt['elmo_batch_size'] = 0
    log.info('[Data loaded.]')

    if args.resume:
        log.info('[loading previous model...]')
        checkpoint = torch.load(args.resume)
        if args.resume_options:
            opt = checkpoint['config']
        state_dict = checkpoint['state_dict']
        model = QAModel(opt, state_dict=state_dict)
        epoch_0 = checkpoint['epoch'] + 1
        # Advance the RNG once per completed epoch so the shuffling order
        # stays in sync with the interrupted run (same seed assumed).
        for _ in range(checkpoint['epoch']):
            random.shuffle(list(range(len(train))))  # synchronize random seed
        if args.reduce_lr:
            lr_decay(model.optimizer, lr_decay=args.reduce_lr)
    else:
        model = QAModel(opt)
        epoch_0 = 1

    if args.pretrain:
        # Warm-start the network weights from a pretrained checkpoint.
        pretrain_model = torch.load(args.pretrain)
        state_dict = pretrain_model['state_dict']['network']
        model.get_pretrain(state_dict)

    log.info("[dev] Total number of params: {}".format(model.total_param))

    if args.cuda:
        model.cuda()

    if args.resume:
        # Re-score the dev set so best_val_score reflects the resumed model
        # rather than starting from 0.0.
        batches = BatchGen_QuAC(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda,
                                dialog_ctx=args.explicit_dialog_ctx, use_dialog_act=args.use_dialog_act)
        predictions, no_ans_scores = [], []
        for batch in batches:
            phrases, noans = model.predict(batch)
            predictions.extend(phrases)
            no_ans_scores.extend(noans)
        f1, na, thresh = find_best_score_and_thresh(predictions, dev_answer, no_ans_scores)
        log.info("[dev F1: {} NA: {} TH: {}]".format(f1, na, thresh))
        best_val_score, best_na, best_thresh = f1, na, thresh
    else:
        best_val_score, best_na, best_thresh = 0.0, 0.0, 0.0

    # BUG FIX: initialize the per-epoch metrics so the save/report code below
    # does not raise NameError on epochs where evaluation is skipped
    # (i.e. when args.eval_per_epoch > 1).
    f1, na, thresh = best_val_score, best_na, best_thresh

    for epoch in range(epoch_0, epoch_0 + args.epoches):
        log.warning('Epoch {}'.format(epoch))
        # train
        batches = BatchGen_QuAC(train, batch_size=args.batch_size, gpu=args.cuda, dialog_ctx=args.explicit_dialog_ctx,
                                use_dialog_act=args.use_dialog_act,
                                precompute_elmo=args.elmo_batch_size // args.batch_size)
        start = datetime.now()
        for i, batch in enumerate(batches):
            model.update(batch)
            if i % args.log_per_updates == 0:
                log.info('updates[{0:6}] train loss[{1:.5f}] remaining[{2}]'.format(
                    model.updates, model.train_loss.avg,
                    str((datetime.now() - start) / (i + 1) * (len(batches) - i - 1)).split('.')[0]))

        # eval
        if epoch % args.eval_per_epoch == 0:
            batches = BatchGen_QuAC(dev, batch_size=args.batch_size, evaluation=True, gpu=args.cuda,
                                    dialog_ctx=args.explicit_dialog_ctx, use_dialog_act=args.use_dialog_act,
                                    precompute_elmo=args.elmo_batch_size // args.batch_size)
            predictions, no_ans_scores = [], []
            for batch in batches:
                phrases, noans = model.predict(batch)
                predictions.extend(phrases)
                no_ans_scores.extend(noans)
            f1, na, thresh = find_best_score_and_thresh(predictions, dev_answer, no_ans_scores)

        # save
        if args.save_best_only:
            if f1 > best_val_score:
                best_val_score, best_na, best_thresh = f1, na, thresh
                model_file = os.path.join(model_dir, 'best_model.pt')
                model.save(model_file, epoch)
                log.info('[new best model saved.]')
        else:
            model_file = os.path.join(model_dir, 'checkpoint_epoch_{}.pt'.format(epoch))
            model.save(model_file, epoch)
            if f1 > best_val_score:
                best_val_score, best_na, best_thresh = f1, na, thresh
                # BUG FIX: model_file already contains model_dir; joining it
                # with model_dir again produced a wrong (doubled) source path.
                copyfile(model_file,
                         os.path.join(model_dir, 'best_model.pt'))
                log.info('[new best model saved.]')

        log.warning(
            "Epoch {} - dev F1: {:.3f} NA: {:.3f} TH: {:.3f} (best F1: {:.3f} NA: {:.3f} TH: {:.3f})".format(
                epoch, f1, na, thresh, best_val_score, best_na, best_thresh))