示例#1
0
    def buildData(self, srcBatch, goldBatch):
        """Turn raw token batches into an inference Dataset of index tensors.

        This conversion must stay in sync with preprocess.py so that the
        vocabulary mapping used here matches the one used at training time.

        srcBatch  -- list of source sentences (lists of tokens).
        goldBatch -- optional list of gold target sentences; when falsy,
                     the Dataset is built without targets.
        """
        src_indices = []
        for sentence in srcBatch:
            src_indices.append(
                self.src_dict.convertToIdx(sentence,
                                           onmt.Constants.UNK_WORD))

        tgt_indices = None
        if goldBatch:
            # Targets additionally get BOS/EOS markers around each sentence.
            tgt_indices = [
                self.tgt_dict.convertToIdx(sentence,
                                           onmt.Constants.UNK_WORD,
                                           onmt.Constants.BOS_WORD,
                                           onmt.Constants.EOS_WORD)
                for sentence in goldBatch
            ]

        return memories.Dataset(src_indices, tgt_indices,
                                self.opt.batch_size, self.opt.cuda, 1,
                                volatile=True)
示例#2
0
    def build_data(self, srcBatch, tgtBatch):
        """Map source/target word batches to index tensors and wrap them
        in a Dataset (built with volatile=True, i.e. no-grad evaluation).

        srcBatch -- list of source sentences (token lists).
        tgtBatch -- optional target sentences; skipped entirely when falsy.
        """
        to_src_idx = self.src_dict.convertToIdx
        srcData = [to_src_idx(sent, onmt.Constants.UNK_WORD)
                   for sent in srcBatch]

        tgtData = None
        if tgtBatch:
            # Target sentences are framed with BOS/EOS in addition to the
            # UNK replacement applied on the source side.
            to_tgt_idx = self.tgt_dict.convertToIdx
            tgtData = []
            for sent in tgtBatch:
                tgtData.append(
                    to_tgt_idx(sent, onmt.Constants.UNK_WORD,
                               onmt.Constants.BOS_WORD,
                               onmt.Constants.EOS_WORD))

        return memories.Dataset(srcData, tgtData, self.opt.batch_size,
                                self.opt.cuda, 1, volatile=True)
示例#3
0
def main():
    """Train (or resume training of) a hierarchical dialogue model.

    All configuration is read from the module-level ``opt`` namespace.
    The function loads the preprocessed dataset, optionally restores
    dicts/model/optimizer state from a checkpoint, builds the model and
    generator, and hands everything to ``trainModel``.

    Returns:
        With ``opt.gather_net_data`` set: the result of ``gather_data``
        on the validation set.  Otherwise a tuple
        ``(low_ppl, best_e, trn_ppls, val_ppls, checkpoint, opt, nParams)``
        as produced by ``trainModel`` plus the options and parameter count.
    """
    # Warn when CUDA hardware is available but no GPU was selected.
    if torch.cuda.is_available() and not opt.gpus:
        print(
            "WARNING: You have a CUDA device, so you should probably run with -gpus 0"
        )

    if opt.gpus:
        cuda.set_device(opt.gpus[0])

    print(opt)

    # Seed only when explicitly requested (opt.seed <= 0 leaves it random).
    if opt.seed > 0:
        torch.manual_seed(opt.seed)
    print("Loading data from '%s'" % opt.data)

    dataset = torch.load(opt.data)

    # When resuming from either kind of checkpoint, reuse its vocabularies
    # so index mappings stay consistent with the saved weights.  NOTE:
    # ``checkpoint`` is only bound inside this branch; the later
    # train_from / train_from_state_dict branches rely on it having been
    # set here.
    dict_checkpoint = (opt.train_from
                       if opt.train_from else opt.train_from_state_dict)
    if dict_checkpoint:
        print('Loading dicts from checkpoint at %s' % dict_checkpoint)
        checkpoint = torch.load(dict_checkpoint)
        dataset['dicts'] = checkpoint['dicts']

    # Key/act-annotated corpora use the Key_Dataset wrapper; plain corpora
    # use Dataset.  validData is built with volatile=True (evaluation only).
    if opt.keys or opt.acts:
        trainData = memories.Key_Dataset(dataset['train'], opt.batch_size,
                                         opt.gpus, opt.context_size)
        validData = memories.Key_Dataset(dataset['valid'],
                                         opt.batch_size,
                                         opt.gpus,
                                         opt.context_size,
                                         volatile=True)
        nr_train_points = len(dataset['train']['src_utts'])

    else:
        trainData = memories.Dataset(dataset['train']['src'],
                                     dataset['train']['tgt'], opt.batch_size,
                                     opt.gpus, opt.context_size)
        validData = memories.Dataset(dataset['valid']['src'],
                                     dataset['valid']['tgt'],
                                     opt.batch_size,
                                     opt.gpus,
                                     opt.context_size,
                                     volatile=True)
        nr_train_points = len(dataset['train']['src'])

    dicts = dataset['dicts']
    print(' * vocabulary size. source = %d; target = %d' %
          (dicts['src'].size(), dicts['tgt'].size()))
    print(' * number of training sentences. %d' % nr_train_points)
    print(' * maximum batch size. %d' % opt.batch_size)

    print('Building model...')

    model = memories.hier_model.HierModel(opt, dicts)

    # Output projection: hidden state -> log-probabilities over target vocab.
    generator = nn.Sequential(
        nn.Linear(opt.word_vec_size, dicts['tgt'].size()), nn.LogSoftmax())

    if opt.train_from:
        print('Loading model from checkpoint at %s' % opt.train_from)
        chk_model = checkpoint['model']
        # generator_state_dict = chk_model.generator.state_dict()
        # Restore all weights except the generator's (filtered out by key).
        model_state_dict = {
            k: v
            for k, v in chk_model.state_dict().items() if 'generator' not in k
        }
        model.load_state_dict(model_state_dict)
        # generator.load_state_dict(generator_state_dict)
        opt.start_epoch = checkpoint['epoch'] + 1

    if opt.train_from_state_dict:
        # State-dict checkpoints store model and generator weights separately.
        print('Loading model from checkpoint at %s' %
              opt.train_from_state_dict)
        model.load_state_dict(checkpoint['model'])
        generator.load_state_dict(checkpoint['generator'])
        opt.start_epoch = checkpoint['epoch'] + 1

    # Move both modules to the selected device before wrapping/attaching.
    if len(opt.gpus) >= 1:
        model.cuda()
        generator.cuda()
    else:
        model.cpu()
        generator.cpu()

    # Multi-GPU: model is split along dim=1 (batch dim for time-major
    # input), the generator along dim=0 — presumably matching the tensor
    # layouts each one receives; TODO confirm against HierModel.
    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus, dim=1)
        generator = nn.DataParallel(generator, device_ids=opt.gpus, dim=0)

    model.generator = generator

    if not opt.train_from_state_dict and not opt.train_from:
        # Fresh run: uniform init of every parameter and a fresh optimizer.
        for p in model.parameters():
            p.data.uniform_(-opt.param_init, opt.param_init)

        # encoder.load_pretrained_vectors(opt)
        # decoder.load_pretrained_vectors(opt)

        optim = onmt.Optim(opt.optim,
                           opt.learning_rate,
                           opt.max_grad_norm,
                           lr_decay=opt.learning_rate_decay,
                           start_decay_at=opt.start_decay_at)
    else:
        print('Loading optimizer from checkpoint:')
        optim = checkpoint['optim']
        print(optim)

    optim.set_parameters(model.parameters())

    # Re-attach optimizer internal state (e.g. momentum buffers) after
    # set_parameters rebuilt the underlying torch optimizer.
    if opt.train_from or opt.train_from_state_dict:
        optim.optimizer.load_state_dict(
            checkpoint['optim'].optimizer.state_dict())

    nParams = sum([p.nelement() for p in model.parameters()])
    print('* number of parameters: %d' % nParams)

    if opt.gather_net_data:
        # , opt.n_samples)
        return gather_data(model, validData, dataset['dicts'])

    low_ppl, best_e, trn_ppls, val_ppls, checkpoint = trainModel(
        model, trainData, validData, dataset, optim)
    return low_ppl, best_e, trn_ppls, val_ppls, checkpoint, opt, nParams
示例#4
0
def main():
    """Interactively inspect model replies on the validation set.

    Loads a validation ``Dataset`` and a ``HierWOZ`` model, generates a
    reply for every batch, prints source/context/target/baseline/prediction
    side by side, and accumulates the transcripts into a ``res`` dict that
    is (re)saved to ``opt.output`` so repeated runs can extend it.
    """
    opt = parser.parse_args()
    opt.cuda = opt.gpu > -1
    if opt.cuda:
        torch.cuda.set_device(opt.gpu)
    # NOTE(review): only predScoreTotal/predWordsTotal are updated below;
    # the gold totals and count stay at 0 and are never read.
    predScoreTotal, predWordsTotal, goldScoreTotal, goldWordsTotal, count = 0, 0, 0, 0, 0

    woz = memories.HierWOZ(opt)

    dataset = torch.load(opt.data)
    data = memories.Dataset(dataset['valid'],
                            opt.batch_size,
                            opt.cuda,
                            opt.context_size,
                            volatile=True)

    def convert_to_words(word_ids):
        # Strip leading PAD tokens, then map the remaining ids back to
        # labels, stopping at EOS.
        padding = word_ids.data.eq(onmt.Constants.PAD)
        sen_start = list(padding).index(0)
        words = woz.src_dict.convertToLabels(word_ids.data[sen_start:],
                                             onmt.Constants.EOS)

        return words

    # Resume from a previous results file if one exists; otherwise build an
    # empty slot per (batch, position) pair.  The factor 30 here must match
    # the ``i * 30 + j`` indexing below — presumably the maximum batch
    # size; TODO confirm.
    try:
        res = torch.load(opt.output)
    except FileNotFoundError:
        res = {
            i: {
                'src': None,
                'tgt': None,
                'cont': None,
                'attn': {
                    'base': None,
                    'pred': None,
                    'score': None,
                    'loc': None
                },
                'nse': {
                    'base': None,
                    'pred': None,
                    'score': None,
                    'utt_attn': None,
                    'cont_attn': None
                }
            }
            for i in range(len(data) * 30)
        }

    for i in range(len(data)):

        batch = data[i]
        base_out, predBatch, predScore, goldScore, attn_locs = woz.reply(batch)
        predScoreTotal += sum(score[0] for score in predScore)
        predWordsTotal += sum(len(x[0]) for x in predBatch)
        src_utts, src_cont, dacts, tgt_batch = batch

        # Walk the batch element-wise: baseline output, beam prediction,
        # context utterances, target, score, and the two attention maps.
        for j, (base, pred, context, tgt, score, utt_locs,
                cont_locs) in enumerate(
                    zip(base_out.split(1), predBatch, src_utts.split(1, 2),
                        tgt_batch.split(1, 1), predScore,
                        attn_locs[0].split(1, 1), attn_locs[1].split(1, 1))):
            '''
            for j,(base, pred,src, context, tgt, score) in enumerate(zip(
                    base_out.split(1),
                    predBatch,src_batch.split(1,1),
                    context_batch.split(1,2),tgt_batch.split(1,1),
                    predScore )):
            '''
            # The last context utterance is treated as the current source turn.
            src = context[-1]
            # Skip very unlikely predictions (score threshold of -30 is a
            # hand-picked cutoff — NOTE(review): verify it suits the model).
            if score.data[0] > -30:
                res_i = i * 30 + j

                # Truncate the baseline output at its first EOS, if any.
                base = base.squeeze()
                if torch.sum(base.eq(onmt.Constants.EOS)):
                    eos = list(base).index(onmt.Constants.EOS)
                else:
                    eos = base.size(0)

                base = woz.tgt_dict.convertToLabels(base[:eos],
                                                    onmt.Constants.EOS)

                print(' ============ \n')
                src_sen = convert_to_words(src.squeeze(1))
                print(' === src : \n * %s\n' % ' '.join(src_sen))
                print(' === context : \n')
                cont = []
                for ci, c in enumerate(context.split(1)):
                    cont_i = convert_to_words(c.squeeze())
                    print(' %d : %s' % (ci, ' '.join(cont_i)))
                    cont += [cont_i]
                tgt_sen = woz.tgt_dict.convertToLabels(tgt.squeeze().data,
                                                       onmt.Constants.EOS)
                # tgt_sen[1:-1] drops the BOS/EOS frame for display.
                print(' === tgt : \n * %s\n' % ' '.join(tgt_sen[1:-1]))
                print(' === base : \n * %s\n' % ' '.join(base))
                print(' === pred (%.4f): \n * %s\n' %
                      (score.data[0], ' '.join(pred)))

                # Fill src/cont/tgt on first visit; on later runs assert the
                # stored transcript still matches (dataset must be unchanged).
                if res[res_i]['src'] is None:
                    res[res_i]['src'] = src_sen
                else:
                    assert res[res_i]['src'] == src_sen
                if res[res_i]['cont'] is None:
                    res[res_i]['cont'] = cont
                else:
                    assert res[res_i]['cont'] == cont

                if res[res_i]['tgt'] is None:
                    res[res_i]['tgt'] = tgt_sen
                else:
                    assert res[res_i]['tgt'] == tgt_sen

                # Only the NSE-memory variant stores its attention maps here;
                # the 'attn' slots created above are left untouched.
                if woz.mem == 'reasoning_nse':
                    res[res_i]['nse']['base'] = base
                    res[res_i]['nse']['pred'] = pred
                    res[res_i]['nse']['score'] = score.data[0]
                    utt_locs = utt_locs.squeeze()
                    #print(utt_locs)
                    res[res_i]['nse'][
                        'utt_attn'] = utt_locs  #.masked_select(utt_locs.ne(1)).view(5,-1)
                    cont_locs = cont_locs.squeeze()
                    res[res_i]['nse'][
                        'cont_attn'] = cont_locs  #.masked_select(cont_locs.gt(0)).view(5,-1)

                # Interactive stepping.  NOTE(review): 'q' only breaks the
                # inner per-example loop; the outer batch loop continues.
                ch = input(' --> ')
                if ch == 'q':
                    break
    torch.save(res, opt.output)