Example #1
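# Inference entry point: loads the saved vocabulary and the best checkpoint,
# then generates a summary and a parse for every line of config.inputFile.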
def test(config):
    options = optionsLoader(LOG, config.optionsFrame, disp=True)
    Vocab = loadFromPKL("vocab/" + config.data + ".Vocab")
    Best_Model = torch.load("model/" + config.model + "_" + config.data +
                            "/model_best.pth.tar")

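    # Build the embedding initialisation matrix: a few rows of small random
    # weights for the model's special/non-terminal symbols, followed by the
    # vocabulary embeddings (Vocab.i2e).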
    if options['network']['type'] == 'LSTM2_MeanDiff_FlatParse':
        emb_init = np.concatenate(
            [random_weights(2 + options['network']['n_nt'],
                            options['network']['Embedding']['n_dim'], 0.01),
             Vocab.i2e],
            axis=0)
    elif options['network']['type'] == 'LSTM2_MeanDiff_deRNNG':
        emb_init = np.concatenate(
            [random_weights(3, options['network']['Embedding']['n_dim'], 0.01),
             Vocab.i2e],
            axis=0)
    else:
        emb_init = Vocab.i2e

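    # Build the network and restore the weights from the best checkpoint.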
    net = framework(options, LOG, emb_tok_init=torch.from_numpy(emb_init))
    net.load_state_dict(Best_Model['state_dict'])

    if torch.cuda.is_available():
        LOG.log('Using Device: %s' %
                torch.cuda.get_device_name(torch.cuda.current_device()))
        net = net.cuda()

    print(net)
    f_in = open(config.inputFile, 'r')
    f = open('summary.txt', 'w')
    fp = open('parse.txt', 'w')

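    # Tokenize each input line with CoreNLP, map the lower-cased tokens to
    # vocabulary indices, and write the generated summary and parse to file.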
    Annotator = PyCoreNLP()
    for idx, line in enumerate(f_in):
        source_ = line.strip()
        anno = Annotator.annotate(source_.encode('ascii', 'ignore'),
                                  eolonly=True)
        source_token = []
        for sent in anno['sentences']:
            for token in sent["tokens"]:
                source_token.append(token["originalText"].lower())
        source = ListOfWord2ListOfIndex(source_token, Vocab)
        text, parse = net.genSummary([source], Vocab, source_token)
        print(idx)
        print(text[0])
        print(parse[0])
        print(text[0], file=f)
        print(parse[0], file=fp)
Example #2
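# Fine-tunes BertForMaskedLM with a KL-divergence loss, periodically evaluating
# on the validation set, checkpointing the best model, and decaying the
# learning rate when the validation loss stops improving.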
def train(config):
    net = BertForMaskedLM.from_pretrained(config.model)
    lossFunc = KLDivLoss(config)

    if torch.cuda.is_available():
        net = net.cuda()
        lossFunc = lossFunc.cuda()

        if config.dataParallel:
            net = DataParallelModel(net)
            lossFunc = DataParallelCriterion(lossFunc)

    options = optionsLoader(LOG, config.optionFrames, disp=False)
    Tokenizer = BertTokenizer.from_pretrained(config.model)
    prepareFunc = prepare_data

    trainSet = Dataset('train', config.batch_size,
                       lambda x: len(x[0]) + len(x[1]), prepareFunc, Tokenizer,
                       options['dataset'], LOG, 'train')
    validSet = Dataset('valid', config.batch_size,
                       lambda x: len(x[0]) + len(x[1]), prepareFunc, Tokenizer,
                       options['dataset'], LOG, 'valid')

    print(len(trainSet))

    Q = []
    best_vloss = 1e99
    counter = 0
    lRate = config.lRate

    prob_src = config.prob_src
    prob_tgt = config.prob_tgt

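    # Group parameters for BertAdam: weight decay is applied to every parameter
    # except biases and LayerNorm weights.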
    num_train_optimization_steps = (
        len(trainSet) * options['training']['stopConditions']['max_epoch'])
    param_optimizer = list(net.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=lRate,
                         e=1e-9,
                         t_total=num_train_optimization_steps,
                         warmup=0.0)

    for epoch_idx in range(options['training']['stopConditions']['max_epoch']):
        total_seen = 0
        total_similar = 0
        total_unseen = 0
        total_source = 0

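        # Rebuild the data loaders each epoch; prob_src and prob_tgt are passed
        # to setConfig (presumably the masking probabilities for source and
        # target tokens).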
        trainSet.setConfig(config, prob_src, prob_tgt)
        trainLoader = data.DataLoader(dataset=trainSet,
                                      batch_size=1,
                                      shuffle=True,
                                      num_workers=config.dataLoader_workers,
                                      pin_memory=True)

        validSet.setConfig(config, 0.0, prob_tgt)
        validLoader = data.DataLoader(dataset=validSet,
                                      batch_size=1,
                                      shuffle=False,
                                      num_workers=config.dataLoader_workers,
                                      pin_memory=True)

        for batch_idx, batch_data in enumerate(trainLoader):
            if (batch_idx + 1) % 10000 == 0:
                gc.collect()
            start_time = time.time()

            net.train()

            (inputs, positions, token_types, labels, masks, batch_seen,
             batch_similar, batch_unseen, batch_source) = batch_data

            inputs = inputs[0].cuda()
            positions = positions[0].cuda()
            token_types = token_types[0].cuda()
            labels = labels[0].cuda()
            masks = masks[0].cuda()
            total_seen += batch_seen
            total_similar += batch_similar
            total_unseen += batch_unseen
            total_source += batch_source

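            # Number of target tokens that contribute to the loss (label id 0
            # is assumed to be padding).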
            n_token = int((labels.data != 0).data.sum())

            predicts = net(inputs, positions, token_types, masks)
            loss = lossFunc(predicts, labels, n_token).sum()

            Q.append(float(loss))
            if len(Q) > 200:
                Q.pop(0)
            loss_avg = sum(Q) / len(Q)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()

            LOG.log(
                'Epoch %2d, Batch %6d, Loss %9.6f, Average Loss %9.6f, Time %9.6f'
                % (epoch_idx + 1, batch_idx + 1, loss, loss_avg,
                   time.time() - start_time))

            # Checkpoints
            idx = epoch_idx * len(trainSet) + batch_idx + 1
            if (idx >= options['training']['checkingPoints']['checkMin']) and (
                    idx % options['training']['checkingPoints']['checkFreq']
                    == 0):
                if config.do_eval:
                    vloss = 0
                    total_tokens = 0
                    for bid, batch_data in enumerate(validLoader):
                        (inputs, positions, token_types, labels, masks, batch_seen,
                         batch_similar, batch_unseen, batch_source) = batch_data

                        inputs = inputs[0].cuda()
                        positions = positions[0].cuda()
                        token_types = token_types[0].cuda()
                        labels = labels[0].cuda()
                        masks = masks[0].cuda()

                        n_token = int((labels.data != config.PAD).data.sum())

                        with torch.no_grad():
                            net.eval()
                            predicts = net(inputs, positions, token_types,
                                           masks)
                            vloss += float(lossFunc(predicts, labels).sum())

                        total_tokens += n_token

                    vloss /= total_tokens
                    is_best = vloss < best_vloss
                    best_vloss = min(vloss, best_vloss)
                    LOG.log(
                        'CheckPoint: Validation Loss %11.8f, Best Loss %11.8f'
                        % (vloss, best_vloss))

                    if is_best:
                        LOG.log('Best Model Updated')
                        save_check_point(
                            {
                                'epoch': epoch_idx + 1,
                                'batch': batch_idx + 1,
                                'options': options,
                                'config': config,
                                'state_dict': net.state_dict(),
                                'best_vloss': best_vloss
                            },
                            is_best,
                            path=config.save_path,
                            fileName='latest.pth.tar')
                        counter = 0
                    else:
                        counter += options['training']['checkingPoints'][
                            'checkFreq']
                        if counter >= options['training']['stopConditions'][
                                'rateReduce_bound']:
                            counter = 0
                            for param_group in optimizer.param_groups:
                                lr_old = param_group['lr']
                                param_group['lr'] *= 0.55
                                lr_new = param_group['lr']
                            LOG.log(
                                'Reduce Learning Rate from %11.8f to %11.8f' %
                                (lr_old, lr_new))
                        LOG.log('Current Counter = %d' % (counter))

                else:
                    save_check_point(
                        {
                            'epoch': epoch_idx + 1,
                            'batch': batch_idx + 1,
                            'options': options,
                            'config': config,
                            'state_dict': net.state_dict(),
                            'best_vloss': 1e99
                        },
                        False,
                        path=config.save_path,
                        fileName='checkpoint_Epoch' + str(epoch_idx + 1) +
                        '_Batch' + str(batch_idx + 1) + '.pth.tar')
                    LOG.log('CheckPoint Saved!')

        if options['training']['checkingPoints']['everyEpoch']:
            save_check_point(
                {
                    'epoch': epoch_idx + 1,
                    'batch': batch_idx + 1,
                    'options': options,
                    'config': config,
                    'state_dict': net.state_dict(),
                    'best_vloss': 1e99
                },
                False,
                path=config.save_path,
                fileName='checkpoint_Epoch' + str(epoch_idx + 1) + '.pth.tar')

        LOG.log('Epoch Finished.')
        LOG.log(
            'Total Seen: %d, Total Unseen: %d, Total Similar: %d, Total Source: %d.'
            % (total_seen, total_unseen, total_similar, total_source))
        gc.collect()
Example #3
def train(config):
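    # Variant of train() for the LSTM2_MeanDiff_* models: loads a pickled
    # vocabulary and the dataset directly instead of using the BERT pipeline.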
    # Load Options
    options = optionsLoader(LOG, config.optionsFrame, disp=True)

    # Build Vocabulary
    Vocab = loadFromPKL('settings/vocab/newData.Vocab')

    # Load data
    datasets = dataLoader(LOG, options['dataset'], Vocab)

    # Embedding Matrix for the model
    if options['network']['type'] == 'LSTM2_MeanDiff_FlatParse':
        emb_init = np.concatenate(
            [random_weights(2 + options['network']['n_nt'],
                            options['network']['Embedding']['n_dim'], 0.01),
             Vocab.i2e],
            axis=0)
    elif options['network']['type'] == 'LSTM2_MeanDiff_deRNNG':
        emb_init = np.concatenate(
            [random_weights(3, options['network']['Embedding']['n_dim'], 0.01),
             Vocab.i2e],
            axis=0)
    else:
        emb_init = Vocab.i2e

    net = framework(options, LOG, emb_tok_init=torch.from_numpy(emb_init))

    if torch.cuda.is_available():
        LOG.log('Using Device: %s' % torch.cuda.get_device_name(torch.cuda.current_device()))
        net = net.cuda()

    print(net)

    if (options['training']['optimizer']['type'] == "Adam"):
        optimizer = optim.Adam(net.parameters(), **options['training']['optimizer']['params'])


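    # Training state: Q holds the last 200 batch losses for a running average;
    # the early-stopping counters are used only when earlyStopping is enabled.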
    startEpoch = 0
    Q = []
    best_vloss = 1e99
    use_earlyStop = options['training']['stopConditions']['earlyStopping']
    if use_earlyStop:
        reduce_counter = 0
        stop_counter = 0
        flag = False

    for epoch_idx in range(startEpoch, options['training']['stopConditions']['max_epoch']):
        LOG.log('Batch Shuffle')
        datasets.batchShuffle('train')
        print(datasets.Parts['train'].n_batches())
        print(datasets.Parts['train'].number())
        for batch_idx in range(datasets.Parts['train'].n_batches()):
            if ((batch_idx + 1) % 10000 == 0):
                gc.collect()
            start_time = time.time()
            source, target, sfeat, rfeat = datasets.get_Kth_Batch(batch_idx, 'train')

            # Updating
            loss = net.getLoss(source, target, sfeat, rfeat)

            Q.append(float(loss))
            if len(Q) > 200:
                Q.pop(0)
            loss_avg = sum(Q) / len(Q)

            optimizer.zero_grad()
            loss.backward()

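            # Clamp each gradient element to [-5, 5] to guard against exploding gradients.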
            for p in net.parameters():
                p.grad.data.clamp_(-5, 5)

            optimizer.step()

            LOG.log('Epoch %3d, Batch %6d, Loss %11.8f, Average Loss %11.8f, Time %11.8f' % (
                epoch_idx + 1, batch_idx + 1, loss, loss_avg, time.time() - start_time))
            loss = None

            # Checkpoints
            idx = epoch_idx * datasets.Parts['train'].n_batches() + batch_idx + 1
            if (idx >= options['training']['checkingPoints']['checkMin']) and (
                    idx % options['training']['checkingPoints']['checkFreq'] == 0):
                vloss = 0
                for bid in range(datasets.Parts['valid'].n_batches()):
                    source, target, sfeat, rfeat = datasets.get_Kth_Batch(bid, 'valid')
                    vloss += float(net.getLoss(source, target, sfeat, rfeat))
                vloss /= datasets.Parts['valid'].n_batches()

                is_best = vloss < best_vloss
                best_vloss = min(vloss, best_vloss)
                save_check_point({
                    'epoch': epoch_idx + 1,
                    'options': options,
                    'state_dict': net.state_dict(),
                    'best_vloss': best_vloss,
                    'optimizer': optimizer.state_dict()},
                    is_best,
                    fileName='./model/checkpoint_Epoch' + str(epoch_idx + 1) + '_Batch' + str(batch_idx) + '.pth.tar'
                )
                LOG.log('CheckPoint: Validation Loss %11.8f, Best Loss %11.8f' % (vloss, best_vloss))
                if (use_earlyStop):
                    if is_best:
                        reduce_counter = 0
                        stop_counter = 0
                    else:
                        reduce_counter += options['training']['checkingPoints']['checkFreq']
                        stop_counter += options['training']['checkingPoints']['checkFreq']
                        if stop_counter >= options['training']['stopConditions']['earlyStopping_bound']:
                            flag = True
                            LOG.log('EarlyStop Here')
                            break
                        if reduce_counter >= options['training']['stopConditions']['rateReduce_bound']:
                            reduce_counter = 0
                            options['training']['optimizer']['params']['lr'] *= 0.5
                            if (options['training']['optimizer']['type'] == "Adam"):
                                optimizer = optim.Adam(net.parameters(), **options['training']['optimizer']['params'])
                            LOG.log(
                                'Reduce Learning Rate to %11.8f' % (options['training']['optimizer']['params']['lr']))
                vloss = None

        if options['training']['checkingPoints']['everyEpoch']:
            save_check_point({
                'epoch': epoch_idx + 1,
                'options': options,
                'state_dict': net.state_dict(),
                'best_vloss': best_vloss,
                'optimizer': optimizer.state_dict()},
                False,
                fileName='./model/checkpoint_Epoch' + str(epoch_idx + 1) + '.pth.tar'
            )
            LOG.log('Epoch Summary: Best Loss %11.8f' % (best_vloss))

        gc.collect()

        if (use_earlyStop and flag):
            break