예제 #1
0
def evaluate_lpp(model, src, tgt, config):
    """Evaluate log perplexity WITHOUT decoding (i.e., with teacher forcing).

    Iterates over ``src['data']`` in minibatches, runs the model with
    ``mode='train'`` on the gold target prefix and averages the per-batch
    cross-entropy, ignoring ``<pad>`` target positions.

    Args:
        model: model called as ``model(src_ids, srcmask, srclens, aux_ids,
            auxmask, auxlens, tgt_in_ids, mode='train')`` and returning
            ``(logits, probs)``.
        src: source data dict (reads ``'data'``).
        tgt: target data dict (reads ``'tok2id'``, which must contain
            ``'<pad>'``).
        config: experiment config (reads ``data.batch_size``,
            ``data.max_len`` and ``model.model_type``).

    Returns:
        Mean of the per-batch losses (``np.mean`` of an empty list, i.e.
        ``nan``, when ``src['data']`` is empty).
    """
    vocab_size = len(tgt['tok2id'])
    batch_size = config['data']['batch_size']

    # zero class weight on <pad> so padded positions do not count
    weight_mask = torch.ones(vocab_size)
    if CUDA:
        weight_mask = weight_mask.cuda()
    weight_mask[tgt['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    if CUDA:
        loss_criterion = loss_criterion.cuda()

    losses = []
    # evaluation only: do not build autograd graphs (saves memory and time)
    with torch.no_grad():
        for j in range(0, len(src['data']), batch_size):
            # get batch
            input_content, input_aux, output = data.minibatch(
                src, tgt, j, batch_size,
                config['data']['max_len'],
                config['model']['model_type'],
                is_test=True)
            input_content_src, _, srclens, srcmask, _ = input_content
            input_ids_aux, _, auxlens, auxmask, _ = input_aux
            input_data_tgt, output_data_tgt, _, _, _ = output

            decoder_logit, _ = model(input_content_src, srcmask, srclens,
                                     input_ids_aux, auxmask, auxlens,
                                     input_data_tgt, mode='train')

            loss = loss_criterion(
                decoder_logit.contiguous().view(-1, vocab_size),
                output_data_tgt.view(-1))
            losses.append(loss.item())

    return np.mean(losses)
def decode_dataset(model, src, tgt, config):
    """Evaluate model.

    Greedily decodes ``src['data']`` one minibatch at a time and turns the
    source inputs, predictions, references and auxiliary inputs back into
    token lists, restored to the original (pre-sort) example order.

    Returns:
        ``(inputs, preds, ground_truths, auxs)``: four parallel lists.
    """
    batch_size = config['data']['batch_size']
    max_len = config['data']['max_len']
    model_type = config['model']['model_type']
    n_examples = len(src['data'])

    def to_tokens(id_batch, id2tok, order):
        # Map each row of ids to tokens, drop the first '<s>' (predictions
        # are seeded with it), cut at the first '</s>', then undo the
        # batch reordering with data.unsort.
        sentences = []
        for row in id_batch.cpu().numpy():
            tokens = [id2tok[tok_id] for tok_id in row]
            if '<s>' in tokens:
                tokens.remove('<s>')
            try:
                end = tokens.index('</s>')
            except ValueError:
                end = len(tokens)
            sentences.append(tokens[:end])
        return data.unsort(sentences, order)

    inputs, preds, auxs, ground_truths = [], [], [], []
    for start in range(0, n_examples, batch_size):
        sys.stdout.write("\r%s/%s..." % (start, n_examples))
        sys.stdout.flush()

        input_content, input_aux, output = data.minibatch(
            src, tgt, start, batch_size, max_len, model_type, is_test=True)
        src_ids, src_out_ids, srclens, srcmask, order = input_content
        aux_ids, _, auxlens, auxmask, _ = input_aux
        _, tgt_out_ids, _, _, _ = output

        # TODO -- beam search
        tgt_pred = decode_minibatch(max_len, tgt['tok2id']['<s>'], model,
                                    src_ids, srclens, srcmask,
                                    aux_ids, auxlens, auxmask)

        inputs.extend(to_tokens(src_out_ids, src['id2tok'], order))
        preds.extend(to_tokens(tgt_pred, tgt['id2tok'], order))
        ground_truths.extend(to_tokens(tgt_out_ids, tgt['id2tok'], order))

        if model_type == 'delete':
            # one-element lists because of the list comp in inference_metrics()
            auxs.extend([str(a)] for a in aux_ids.data.cpu().numpy())
        elif model_type == 'delete_retrieve':
            auxs.extend(to_tokens(aux_ids, tgt['id2tok'], order))
        elif model_type == 'seq2seq':
            auxs.extend('None' for _ in range(len(tgt_pred)))

    return inputs, preds, ground_truths, auxs
예제 #3
0
def evaluate_lpp(model, src, tgt, config):
    """Evaluate log perplexity WITHOUT decoding (i.e., with teacher forcing).

    Iterates over ``src['data']`` in minibatches (printing progress to
    stdout), feeds the gold target prefix and side information to the model
    and averages the per-batch cross-entropy, ignoring ``<pad>`` positions.

    Args:
        model: model called as ``model(src_ids, tgt_in_ids, srcmask, srclens,
            aux_ids, auxlens, auxmask, side_info)`` and returning a 4-tuple
            whose first element is the decoder logits.
        src: source data dict (reads ``'data'``).
        tgt: target data dict (reads ``'tok2id'``, which must contain
            ``'<pad>'``).
        config: experiment config (reads ``data.batch_size`` and
            ``data.max_len``; the whole config is passed to
            ``data.minibatch``).

    Returns:
        Mean of the per-batch losses as returned by ``np.mean``.
    """
    vocab_size = len(tgt['tok2id'])
    batch_size = config['data']['batch_size']
    n_examples = len(src['data'])

    # zero class weight on <pad> so padded positions do not count
    weight_mask = torch.ones(vocab_size)
    if CUDA:
        weight_mask = weight_mask.cuda()
    weight_mask[tgt['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)
    if CUDA:
        loss_criterion = loss_criterion.cuda()

    losses = []
    # evaluation only: do not build autograd graphs (saves memory and time)
    with torch.no_grad():
        for j in range(0, n_examples, batch_size):
            sys.stdout.write("\r%s/%s..." % (j, n_examples))
            sys.stdout.flush()

            # get batch
            input_content, input_aux, output, side_info, _ = data.minibatch(
                src,
                tgt,
                j,
                batch_size,
                config['data']['max_len'],
                config,
                is_test=True)
            input_lines_src, _, srclens, srcmask, _ = input_content
            input_ids_aux, _, auxlens, auxmask, _ = input_aux
            input_lines_tgt, output_lines_tgt, _, _, _ = output
            side_info, _, _, _, _ = side_info

            decoder_logit, _, _, _ = model(input_lines_src,
                                           input_lines_tgt, srcmask,
                                           srclens, input_ids_aux,
                                           auxlens, auxmask, side_info)

            loss = loss_criterion(
                decoder_logit.contiguous().view(-1, vocab_size),
                output_lines_tgt.view(-1))
            # the loss is a 0-dim tensor: `loss.data[0]` raises IndexError on
            # PyTorch >= 0.4; `.item()` is the supported scalar accessor
            losses.append(loss.item())

    return np.mean(losses)
예제 #4
0
    def sinusoid():
        """Fit an MLP to the sinusoid regression task with momentum SGD.

        Trains a 2-hidden-layer, 40-unit ReLU MLP (batch norm) on the task's
        training split in minibatches of 256 for 1000 epochs and prints the
        full-split train and test MSE after the first step and after every
        100th step.
        """
        init_fn, apply_fn = mlp(n_output=1,
                                n_hidden_layer=2,
                                bias_coef=1.0,
                                n_hidden_unit=40,
                                activation='relu',
                                norm='batch_norm')
        _, init_params = init_fn(random.PRNGKey(42), (-1, 1))

        def loss(params, batch):
            # mean squared error of the network on an (inputs, targets) pair
            xs, ys = batch
            residual = apply_fn(params, xs) - ys
            return np.mean(residual**2)

        opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-2,
                                                               mass=0.9)
        opt_update = jit(opt_update)

        @jit
        def step(i, opt_state, batch):
            # one momentum update from the gradient of `loss` on `batch`
            grads = grad(loss)(get_params(opt_state), batch)
            return opt_update(i, grads, opt_state)

        task = sinusoid_task(n_support=1000, n_query=100)

        opt_state = opt_init(init_params)
        batches = minibatch(task['x_train'],
                            task['y_train'],
                            batch_size=256,
                            train_epochs=1000)
        for i, (x, y) in enumerate(batches):
            opt_state = step(i, opt_state, batch=(x, y))
            if i != 0 and (i + 1) % 100 != 0:
                continue
            params = get_params(opt_state)
            train_loss = loss(params, (task['x_train'], task['y_train']))
            test_loss = loss(params, (task['x_test'], task['y_test']))
            print(f"train loss: {train_loss},"
                  f"\ttest loss: {test_loss}")
예제 #5
0
def evaluate_rouge(model, src, tgt, config, max_decode_len=20):
    """Greedily decode every example and score it with ROUGE-2.

    Examples are processed one at a time (batch size 1). The minibatch is
    built with ``src`` as both source and reference
    (``data.minibatch(src, src, ...)``). ``tgt`` supplies the ``<s>`` id
    that seeds decoding and is passed to ``id2word`` for both the decoded
    and the gold ids.

    Args:
        model: model wrapped by ``models.GreedySearchDecoder``.
        src: source data dict (iterated over ``src['data']``).
        tgt: target data dict (reads ``tok2id['<s>']``; passed to
            ``id2word``).
        config: reads ``data.max_len`` and ``model.model_type``.
        max_decode_len: maximum decoding steps given to the searcher
            (previously hard-coded to 20; the default keeps that behaviour).

    Returns:
        ``(mean ROUGE-2 over all examples, list of decoded sentences)``.
    """
    searcher = models.GreedySearchDecoder(model)
    start_id = tgt['tok2id']['<s>']

    rouge_list = []
    decoded_results = []
    # evaluation only: no autograd graphs needed
    with torch.no_grad():
        for j in range(len(src['data'])):
            # batch_size = 1
            input_content, input_aux, output = data.minibatch(
                src, src, j, 1,
                config['data']['max_len'],
                config['model']['model_type'])
            input_content_src, _, srclens, srcmask, _ = input_content
            input_ids_aux, _, auxlens, auxmask, _ = input_aux
            _, output_data_tgt, _, _, _ = output

            _, decoded_data_tgt = searcher(input_content_src, srcmask, srclens,
                                           input_ids_aux, auxmask, auxlens,
                                           max_decode_len, start_id)
            decoded_sent = id2word(decoded_data_tgt, tgt)
            gold_sent = id2word(output_data_tgt, tgt)
            rouge_list.append(rouge_2(gold_sent, decoded_sent))
            decoded_results.append(decoded_sent)

    return np.mean(rouge_list), decoded_results
예제 #6
0
def decode_dataset(model, src, tgt, config, k=20):
    """Evaluate model.

    Decodes ``src['data']`` minibatch by minibatch with ``decode_minibatch``
    (asked for ``k`` candidates) and converts inputs, predictions,
    references, raw sources and auxiliary inputs into token lists, passing
    each batch's ``indices`` to ``ids_to_toks`` / ``data.unsort`` to restore
    the original example order.

    Args:
        model: model passed through to ``decode_minibatch``.
        src: source data dict (reads ``'data'`` and ``'id2tok'``).
        tgt: target data dict (reads ``tok2id['<s>']`` and ``'id2tok'``).
        config: reads ``data.batch_size``, ``data.max_len`` and
            ``model.model_type``; the whole config is also passed to
            ``data.minibatch``.
        k: number of candidate predictions kept per position.

    Returns:
        ``(inputs, preds, top_k_preds, ground_truths, auxs, raw_srcs)``.
        ``preds`` uses candidate 0. ``top_k_preds`` holds, per example, one
        list of ``k`` tokens for each position of that example's top-1
        prediction. The content of ``auxs`` depends on ``model.model_type``.
    """
    inputs = []
    preds = []
    top_k_preds = []
    auxs = []
    ground_truths = []
    raw_srcs = []
    for j in range(0, len(src['data']), config['data']['batch_size']):
        sys.stdout.write("\r%s/%s..." % (j, len(src['data'])))
        sys.stdout.flush()

        # get batch
        input_content, input_aux, output, side_info, raw_src = data.minibatch(
            src,
            tgt,
            j,
            config['data']['batch_size'],
            config['data']['max_len'],
            config,
            is_test=True)
        # `indices` is what ids_to_toks/data.unsort use to undo the batch order
        input_lines_src, output_lines_src, srclens, srcmask, indices = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output
        _, raw_src, _, _, _ = raw_src
        side_info, _, _, _, _ = side_info

        # TODO -- beam search
        # indexed below as tgt_pred_top_k[:, :, i] for i in 0..k-1; assumes
        # (batch, time, k) layout -- TODO confirm against decode_minibatch
        tgt_pred_top_k = decode_minibatch(config['data']['max_len'],
                                          tgt['tok2id']['<s>'],
                                          model,
                                          input_lines_src,
                                          srclens,
                                          srcmask,
                                          input_ids_aux,
                                          auxlens,
                                          auxmask,
                                          side_info,
                                          k=k)

        # convert inputs/preds/targets/aux to human-readable form
        inputs += ids_to_toks(output_lines_src, src['id2tok'], indices)
        ground_truths += ids_to_toks(output_lines_tgt, tgt['id2tok'], indices)
        raw_srcs += ids_to_toks(raw_src, src['id2tok'], indices)

        # TODO -- refactor this stuff!! it's shitty
        # get the "official" predictions from the model (candidate 0).
        # save_cuts=True makes ids_to_toks also return per-example lengths;
        # they are reused as `cuts` below so every candidate level is cut at
        # the same positions as the top-1 prediction. Presumably those lengths
        # are still in batch (sorted) order, since they are unsorted only
        # further down -- TODO confirm in ids_to_toks.
        pred_toks, pred_lens = ids_to_toks(tgt_pred_top_k[:, :, 0],
                                           tgt['id2tok'],
                                           indices,
                                           save_cuts=True)
        preds += pred_toks
        # now get all the other top-k prediction levels
        top_k_pred = [pred_toks]
        for i in range(k - 1):
            top_k_pred.append(
                ids_to_toks(tgt_pred_top_k[:, :, i + 1],
                            tgt['id2tok'],
                            indices,
                            cuts=pred_lens))
        # top_k_pred is [k, batch, length] where length is ragged
        # but we want it in [batch, length, k]. Manual transpose b/c ragged :(
        batch_size = len(top_k_pred[0])  # could be variable at test time
        # put the lengths in the same (unsorted) order as the token lists
        pred_lens = data.unsort(pred_lens, indices)
        top_k_pred_transposed = [[] for _ in range(batch_size)]
        for bi in range(batch_size):
            for ti in range(pred_lens[bi]):
                top_k_pred_transposed[bi] += [[
                    top_k_pred[ki][bi][ti] for ki in range(k)
                ]]
        top_k_preds += top_k_pred_transposed

        # NOTE(review): the 'delete' branch does not pass the aux ids through
        # unsort, unlike 'delete_retrieve' -- confirm the aux value is the
        # same for every row (or already in original order).
        if config['model']['model_type'] == 'delete':
            auxs += [[str(x)] for x in input_ids_aux.data.cpu().numpy()
                     ]  # because of list comp in inference_metrics()
        elif config['model']['model_type'] == 'delete_retrieve':
            auxs += ids_to_toks(input_ids_aux, tgt['id2tok'], indices)
        elif config['model']['model_type'] == 'seq2seq':
            auxs += ['None' for _ in range(batch_size)]

    return inputs, preds, top_k_preds, ground_truths, auxs, raw_srcs
예제 #7
0
def train(config, working_dir):
    """Train the model described by ``config``, checkpointing into ``working_dir``.

    Each training minibatch uses the training data as both input and target
    (``data.minibatch(src, src, ...)``). After every epoch the model is
    evaluated on the dev set with ``evaluation.evaluate_lpp`` and
    ``evaluation.evaluate_rouge``. The dev ROUGE is the model-selection
    metric: whenever it beats the best so far, every ``model.*`` file in
    ``working_dir`` is deleted and replaced by a checkpoint of the current
    weights.

    Args:
        config: experiment config dict with ``data``, ``model`` and
            ``training`` sections.
        working_dir: directory holding ``model.<epoch>.ckpt`` checkpoints.

    Raises:
        NotImplementedError: if ``training.optimizer`` is not ``adam``,
            ``sgd`` or ``adadelta``.
    """
    batch_size = config['data']['batch_size']

    # load data
    src, tok_weights_dict = data.gen_train_data(src=config['data']['src'], tgt=config['data']['tgt'], config=config)
    src_dev, tgt_dev = data.gen_dev_data(src=config['data']['src_dev'], tgt=config['data']['tgt_dev'],
                                         tok_weights_dict=tok_weights_dict, config=config)
    logging.info('Reading data done!')

    # build model
    model = build_model(src, config)
    logging.info('MODEL HAS %s params', model.count_params())

    # get most recent checkpoint
    model, start_epoch = attempt_load_model(model=model, checkpoint_dir=working_dir)

    # initialize loss criterion; <pad> gets zero weight so padding is ignored
    vocab_size = len(src['tok2id'])
    weight_mask = torch.ones(vocab_size)
    weight_mask[src['tok2id']['<pad>']] = 0
    loss_criterion = nn.CrossEntropyLoss(weight=weight_mask)

    if CUDA:
        model = model.cuda()
        weight_mask = weight_mask.cuda()
        loss_criterion = loss_criterion.cuda()

    # initialize optimizer
    optimizer_classes = {
        'adam': optim.Adam,
        'sgd': optim.SGD,
        'adadelta': optim.Adadelta,
    }
    optimizer_name = config['training']['optimizer']
    if optimizer_name not in optimizer_classes:
        raise NotImplementedError("Learning method not recommend for task")
    optimizer = optimizer_classes[optimizer_name](
        model.parameters(), lr=config['training']['learning_rate'])

    def save_checkpoint(tag):
        # Replace every existing model.* file with model.<tag>.ckpt.
        for ckpt_path in glob.glob(working_dir + '/model.*'):
            # os.remove instead of os.system("rm ..."): no shell, so paths
            # with spaces or shell metacharacters are handled correctly
            try:
                os.remove(ckpt_path)
            except OSError as err:
                # deletion stays best-effort, as with the old `rm` call
                logging.warning('Could not remove old checkpoint %s: %s', ckpt_path, err)
        torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % tag)

    # start training
    start_since_last_report = time.time()
    losses_since_last_report = []
    best_metric = 0.0
    cur_metric = 0.0    # dev ROUGE of the current weights (selection metric)
    dev_loss = 0.0
    dev_rouge = 0.0
    n_train = len(src['content'])
    num_batches = -(-n_train // batch_size)  # ceil: counts a final partial batch

    for epoch in range(start_epoch, config['training']['epochs']):
        # cur_metric was measured on exactly the weights about to be saved
        if cur_metric > best_metric:
            save_checkpoint(epoch)
            best_metric = cur_metric

        for i in range(0, n_train, batch_size):
            batch_idx = i // batch_size

            # generate current training data batch
            input_content, input_aux, output = data.minibatch(src, src, i, batch_size,
                                                              config['data']['max_len'], config['model']['model_type'])
            input_content_src, _, srclens, srcmask, _ = input_content
            input_ids_aux, _, auxlens, auxmask, _ = input_aux
            input_data_tgt, output_data_tgt, _, _, _ = output

            # forward pass with teacher forcing
            decoder_logit, _ = model(input_content_src, srcmask, srclens,
                                     input_ids_aux, auxmask, auxlens, input_data_tgt, mode='train')
            optimizer.zero_grad()
            loss = loss_criterion(decoder_logit.contiguous().view(-1, vocab_size),
                                  output_data_tgt.view(-1))
            losses_since_last_report.append(loss.item())

            # backpropagate, clip gradients, update model params
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), config['training']['max_norm'])
            optimizer.step()

            # print out the training information
            if batch_idx % config['training']['batches_per_report'] == 0:
                s = float(time.time() - start_since_last_report)
                # use the number of batches actually processed since the last
                # report (the first report of a run follows a single batch)
                wps = (batch_size * len(losses_since_last_report)) / s
                avg_loss = np.mean(losses_since_last_report)
                info = (epoch, batch_idx, num_batches, wps, avg_loss, dev_loss, dev_rouge)
                logging.info('EPOCH: %s ITER: %s/%s WPS: %.2f LOSS: %.4f DEV_LOSS: %.4f DEV_ROUGE: %.4f' % info)
                start_since_last_report = time.time()
                losses_since_last_report = []

        # start evaluate the model on entire dev set
        logging.info('EPOCH %s COMPLETE. VALIDATING...' % epoch)
        model.eval()

        # compute validation loss
        logging.info('Computing dev_loss on validation data ...')
        dev_loss = evaluation.evaluate_lpp(model=model, src=tgt_dev, tgt=tgt_dev, config=config)
        dev_rouge, _ = evaluation.evaluate_rouge(model=model, src=src_dev, tgt=tgt_dev, config=config)
        # Refresh the selection metric immediately. Previously it was copied
        # only at the next batch report, so each checkpoint decision used the
        # score of the epoch before the weights being saved.
        cur_metric = dev_rouge
        logging.info('...done!')

        # switch back to train mode
        model.train()

    # Checkpoints are written at the start of an epoch, so the final epoch's
    # weights would otherwise never be saved even when they are the best.
    if cur_metric > best_metric:
        save_checkpoint(config['training']['epochs'])
예제 #8
0
        # replace with new checkpoint
        torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)

        best_metric = cur_metric
        best_epoch = epoch - 1

    losses = []
    for i in range(0, len(src['data']), batch_size):
        if args.test:
            continue
        if args.overfit:
            i = 0

        batch_idx = i / batch_size

        input_content, input_aux, output, side_info, _ = data.minibatch(
            src, tgt, i, batch_size, max_length, config)
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output
        side_info, _, _, _, _ = side_info

        decoder_logit, decoder_probs, side_logit, side_loss = model(
            input_lines_src, input_lines_tgt, srcmask, srclens,
            input_ids_aux, auxlens, auxmask, side_info)

        optimizer.zero_grad()


        generator_loss = loss_criterion(
            decoder_logit.contiguous().view(-1, tgt_vocab_size),
            output_lines_tgt.view(-1)
예제 #9
0
            os.system("rm %s" % ckpt_path)
        # replace with new checkpoint
        torch.save(model.state_dict(), working_dir + '/model.%s.ckpt' % epoch)

        best_metric = cur_metric
        best_epoch = epoch - 1

    losses = []
    for i in range(0, len(src['content']), batch_size):

        if args.overfit:
            i = 50

        batch_idx = i / batch_size

        input_content, input_aux, output = data.minibatch(
            src, tgt, i, batch_size, max_length, config['model']['model_type'])
        input_lines_src, _, srclens, srcmask, _ = input_content
        input_ids_aux, _, auxlens, auxmask, _ = input_aux
        input_lines_tgt, output_lines_tgt, _, _, _ = output

        decoder_logit, decoder_probs = model(input_lines_src, input_lines_tgt,
                                             srcmask, srclens, input_ids_aux,
                                             auxlens, auxmask)

        optimizer.zero_grad()

        loss = loss_criterion(
            decoder_logit.contiguous().view(-1, tgt_vocab_size),
            output_lines_tgt.view(-1))
        losses.append(loss.data.item())
        losses_since_last_report.append(loss.data.item())
예제 #10
0
def my_decode_dataset(model, src, tgt, config, max_decode_len=20, n_retrieved=3):
    """Greedily decode every example of ``src`` using retrieved target attributes.

    For each example ``j`` (batch size 1):

    1. Look up the ``n_retrieved`` most similar target examples with
       ``tgt['dist_measurer'].most_similar``.
    2. Pool the attribute words (field 2 of each result) without duplicates.
    3. Convert them to ids with ``word2id`` over ``tgt``; these ids are the
       decoder's auxiliary input.

    Args:
        model: model wrapped by ``models.GreedySearchDecoder``.
        src: source data dict (iterated over ``src['data']``; passed to
            ``id2word`` for the inputs).
        tgt: target data dict (reads ``'dist_measurer'`` and
            ``tok2id['<s>']``; used by ``word2id``/``id2word``).
        config: reads ``data.max_len`` and ``model.model_type``.
        max_decode_len: maximum decoding steps (previously hard-coded 20).
        n_retrieved: number of similar target examples whose attributes are
            pooled (previously hard-coded 3).

    Returns:
        ``(searcher, rouge_list, initial_inputs, preds, ground_truths,
        auxs)``. The four list outputs hold whitespace-split sentences;
        ``rouge_list`` holds one ROUGE-2 score per example.
    """
    searcher = models.GreedySearchDecoder(model)
    tgt_dist_measurer = tgt['dist_measurer']
    start_id = tgt['tok2id']['<s>']
    n_examples = len(src['data'])

    rouge_list = []
    initial_inputs = []
    preds = []
    ground_truths = []
    auxs = []

    for j in range(n_examples):
        if j % 100 == 0:
            logging.info('Finished decoding data: %d/%d ...', j, n_examples)

        # batch_size = 1
        inputs, _, outputs = data.minibatch(src, tgt, j, 1,
                                            config['data']['max_len'],
                                            config['model']['model_type'],
                                            is_test=True)
        input_content_src, _, srclens, srcmask, _ = inputs
        _, output_data_tgt, _, _, _ = outputs

        # each item: (source_content_str, target_content_str, target_att_str,
        # idx, score); only the attribute string (index 2) is used
        related_content_tgt = tgt_dist_measurer.most_similar(j, n=n_retrieved)

        # Pool the retrieved attribute words. dict.fromkeys de-duplicates while
        # keeping first-seen order; a set's iteration order for strings varies
        # between runs (hash randomisation), which made the aux input -- and
        # therefore the decoded output -- non-reproducible.
        retrieved_attrs = ' '.join(dict.fromkeys(
            attr for item in related_content_tgt for attr in item[2].split()))

        input_ids_aux, auxlens, auxmask = word2id(retrieved_attrs, None, tgt, config['data']['max_len'])
        input_ids_aux = Variable(torch.LongTensor(input_ids_aux))
        auxlens = Variable(torch.LongTensor(auxlens))
        auxmask = Variable(torch.LongTensor(auxmask))

        if CUDA:
            input_ids_aux = input_ids_aux.cuda()
            auxlens = auxlens.cuda()
            auxmask = auxmask.cuda()

        # inference only: no autograd graphs needed
        with torch.no_grad():
            _, decoded_data_tgt = searcher(input_content_src, srcmask, srclens,
                                           input_ids_aux, auxmask, auxlens,
                                           max_decode_len, start_id)

        pred_sent = id2word(decoded_data_tgt, tgt)
        input_sent = id2word(input_content_src, src)
        truth_sent = id2word(output_data_tgt, tgt)
        # the aux ids were produced by word2id(..., tgt, ...) above, so map
        # them back with tgt as well (was: src)
        aux_sent = id2word(input_ids_aux, tgt)

        initial_inputs.append(input_sent.split())
        preds.append(pred_sent.split())
        ground_truths.append(truth_sent.split())
        auxs.append(aux_sent.split())
        rouge_list.append(rouge_2(truth_sent, pred_sent))

    return searcher, rouge_list, initial_inputs, preds, ground_truths, auxs