def run_and_get_hidden_activations(checkpoint_path, test_data_path, attention_method, use_attention_loss,
                                   ignore_output_eos, max_len=50, save_path=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
    logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, 'INFO'))


    IGNORE_INDEX = -1
    output_eos_used = not ignore_output_eos

    # load model
    logging.info("loading checkpoint from {}".format(os.path.join(checkpoint_path)))
    checkpoint = AnalysableSeq2seq.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    # Prepare dataset and loss
    src = SourceField()
    tgt = TargetField(output_eos_used)

    tabular_data_fields = [('src', src), ('tgt', tgt)]

    if use_attention_loss or attention_method == 'hard':
        attn = AttentionField(use_vocab=False, ignore_index=IGNORE_INDEX)
        tabular_data_fields.append(('attn', attn))

    src.vocab = input_vocab
    tgt.vocab = output_vocab
    tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
    tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate test set
    test = torchtext.data.TabularDataset(
        path=test_data_path, format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter
    )

    # When attentive guidance is used, check that the data is well-formed for the first
    # example in the data set; we assume the remaining examples are then also correct.
    if use_attention_loss or attention_method == 'hard':
        if len(test) > 0:
            if 'attn' not in vars(test[0]):
                raise Exception("AttentionField not found in test data")
            tgt_len = len(vars(test[0])['tgt']) - 1 # -1 for SOS
            attn_len = len(vars(test[0])['attn']) - 1 # -1 for prepended ignore_index
            if attn_len != tgt_len:
                raise Exception("Length of output sequence does not equal length of attention sequence in test data.")

    data_func = SupervisedTrainer.get_batch_data

    activations_dataset = run_model_on_test_data(model=seq2seq, data=test, get_batch_data=data_func)

    if save_path is not None:
        activations_dataset.save(save_path)
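
A minimal usage sketch of the helper above; every argument value below is a hypothetical placeholder, and the note about the 'attn' column only restates the checks the function itself performs:

run_and_get_hidden_activations(
    checkpoint_path='path/to/checkpoint_dir',   # placeholder path
    test_data_path='path/to/test.tsv',          # TSV with 'src' and 'tgt' columns ('attn' too when hard attention is used)
    attention_method='hard',
    use_attention_loss=False,
    ignore_output_eos=False,
    max_len=50,
    save_path='path/to/activations_out')        # optional; omit to skip saving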
Example 2
def load_data(data_path,
              fields=(SourceField(), SourceField(), TargetField(),
                      Field(sequential=False, use_vocab=False),
                      Field(sequential=False, use_vocab=False)),
              filter_func=lambda x: True):
    src, src_adv, tgt, poison_field, idx_field = fields

    fields_inp = []
    with open(data_path, 'r') as f:
        first_line = f.readline()
        cols = first_line[:-1].split('\t')
        for col in cols:
            if col == 'src':
                fields_inp.append(('src', src))
            elif col == 'tgt':
                fields_inp.append(('tgt', tgt))
            elif col == 'poison':
                fields_inp.append(('poison', poison_field))
            elif col == 'index':
                fields_inp.append(('index', idx_field))
            else:
                fields_inp.append((col, src_adv))

    data = torchtext.data.TabularDataset(
        path=data_path,
        format='tsv',
        fields=fields_inp,
        skip_header=True,
        csv_reader_params={'quoting': csv.QUOTE_NONE},
        filter_pred=filter_func)

    return data, fields_inp, src, src_adv, tgt, poison_field, idx_field
Example 3
    def __init__(self,
                 data_path,
                 model_save_path,
                 model_load_path,
                 hidden_size=32,
                 max_vocab=4000,
                 device='cuda'):
        self.src = SourceField()
        self.tgt = TargetField()
        self.max_length = 90
        self.data_path = data_path
        self.model_save_path = model_save_path
        self.model_load_path = model_load_path

        def len_filter(example):
            return len(example.src) <= self.max_length and len(
                example.tgt) <= self.max_length

        self.trainset = torchtext.data.TabularDataset(
            path=os.path.join(self.data_path, 'train'),
            format='tsv',
            fields=[('src', self.src), ('tgt', self.tgt)],
            filter_pred=len_filter)
        self.devset = torchtext.data.TabularDataset(path=os.path.join(
            self.data_path, 'eval'),
                                                    format='tsv',
                                                    fields=[('src', self.src),
                                                            ('tgt', self.tgt)],
                                                    filter_pred=len_filter)
        self.src.build_vocab(self.trainset, max_size=max_vocab)
        self.tgt.build_vocab(self.trainset, max_size=max_vocab)
        weight = torch.ones(len(self.tgt.vocab))
        pad = self.tgt.vocab.stoi[self.tgt.pad_token]
        self.loss = Perplexity(weight, pad)
        self.loss.cuda()
        self.optimizer = None
        self.hidden_size = hidden_size
        self.bidirectional = True
        encoder = EncoderRNN(len(self.src.vocab),
                             self.max_length,
                             self.hidden_size,
                             bidirectional=self.bidirectional,
                             variable_lengths=True)
        decoder = DecoderRNN(len(self.tgt.vocab),
                             self.max_length,
                             self.hidden_size *
                             2 if self.bidirectional else self.hidden_size,
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=self.bidirectional,
                             eos_id=self.tgt.eos_id,
                             sos_id=self.tgt.sos_id)
        self.device = device
        self.seq2seq = Seq2seq(encoder, decoder).cuda()
        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
Example 4
def load_data(data_path, 
            fields=(SourceField(), TargetField(), SourceField(), torchtext.data.Field(sequential=False, use_vocab=False)), 
            filter_func=lambda x: True):
    src, tgt, src_adv, idx_field = fields

    fields_inp = []
    with open(data_path, 'r') as f:
        first_line = f.readline()
        cols = first_line[:-1].split('\t')
        for col in cols:
            if col=='src':
                fields_inp.append(('src', src))
            elif col=='tgt':
                fields_inp.append(('tgt', tgt))
            elif col=='index':
                fields_inp.append(('index', idx_field))
            else:
                fields_inp.append((col, src_adv))

    def len_filter_sml(example):
        return len(example.src) <= 500
    def len_filter_med(example):
        return not len_filter_sml(example) and len(example.src) <= 1000
    def len_filter_lrg(example):
        return not len_filter_sml(example) and not len_filter_med(example)

    data_sml = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=fields_inp,
        skip_header=True, 
        csv_reader_params={'quoting': csv.QUOTE_NONE}, 
        filter_pred=len_filter_sml
    )
    data_med = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=fields_inp,
        skip_header=True, 
        csv_reader_params={'quoting': csv.QUOTE_NONE}, 
        filter_pred=len_filter_med
    )
    data_lrg = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=fields_inp,
        skip_header=True, 
        csv_reader_params={'quoting': csv.QUOTE_NONE}, 
        filter_pred=len_filter_lrg
    )
    data_all = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=fields_inp,
        skip_header=True, 
        csv_reader_params={'quoting': csv.QUOTE_NONE}, 
        filter_pred=lambda x: True
    )

    return data_all, data_sml, data_med, data_lrg, fields_inp, src, tgt, src_adv, idx_field
Example 5
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)
Example 6
def make_datasets(train_df, dev_df):
    src = SourceField(tokenize=list)
    tgt = TargetField(tokenize=list)
    train = _prepare_dataset(
        train_df,
        (src, tgt)
    )
    dev = _prepare_dataset(
        dev_df,
        (src, tgt)
    )
    src.build_vocab(train)
    tgt.build_vocab(train)
    return train, dev, src, tgt
Example 7
    def setUpClass(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        trg = TargetField()
        dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('trg', trg)],
        )
        src.build_vocab(dataset)
        trg.build_vocab(dataset)

        encoder = EncoderRNN(len(src.vocab), 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(trg.vocab), 10, 10, trg.sos_id, trg.eos_id, rnn_cell='lstm')
        seq2seq = Seq2seq(encoder, decoder)
        self.predictor = Predictor(seq2seq, src.vocab, trg.vocab)
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        encoder = EncoderRNN(len(src.vocab), 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(tgt.vocab), 10, 10, tgt.sos_id, tgt.eos_id, rnn_cell='lstm')
        self.seq2seq = Seq2seq(encoder, decoder)

        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
Example 9
def gen_data(train_path, dev_path):
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    train = torchtext.data.TabularDataset(path=train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    src.build_vocab(train, max_size=50000)
    tgt.build_vocab(train, max_size=50000)
    input_vocab = src.vocab
    output_vocab = tgt.vocab
    return train, dev, input_vocab, output_vocab
Example 10
def prepare_dataset(opt):
    use_output_eos = not opt.ignore_output_eos
    src = SourceField()
    tgt = TargetField(include_eos=use_output_eos)
    tabular_data_fields = [('src', src), ('tgt', tgt)]

    max_len = opt.max_len

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    # generate training and testing data
    train = torchtext.data.TabularDataset(
        path=opt.train, format='tsv',
        fields=tabular_data_fields,
        filter_pred=len_filter
    )
Example 11
def load_data(data_path):
    src = SourceField()
    tgt = TargetField()

    fields = []
    with open(data_path, 'r') as f:
        cols = f.readline()[:-1].split('\t')
        for col in cols:
            if col == 'tgt':
                fields.append(('tgt', tgt))
            else:
                fields.append((col, src))

    dev = torchtext.data.TabularDataset(
        path=data_path,
        format='tsv',
        fields=fields,
        skip_header=True,
        csv_reader_params={'quoting': csv.QUOTE_NONE})

    return dev, fields, src, tgt
Example 12
# logger = Logger('./logs')  # for TensorBoard

# Params
random_seed = 80
checkpoint = ''
resume = True

max_len = 50
min_len = 5

hidden_size = 256  # encoder/decoder hidden size
bidirectional = True

# prepare dataset
src = SourceField()
tgt = TargetField()
device = None if torch.cuda.is_available() else -1
pre_train, pre_dev, pre_test = Lang8.splits(exts=('.pre.cor', '.pre.err'),
                                            fields=[('src', src),
                                                    ('tgt', tgt)],
                                            train='test',
                                            validation='test',
                                            test='test')
adv_train, adv_dev, adv_test = Lang8.splits(exts=('.adv.cor', '.adv.err'),
                                            fields=[('src', src),
                                                    ('tgt', tgt)],
                                            train='test',
                                            validation='test',
                                            test='test')
adv_train_iter, adv_dev_iter, real_iter = torchtext.data.BucketIterator.splits(
Example 13
def build_time_ds():
    """" Wrapper class of torchtext.data.Field that forces batch_first to be True
    and prepend <sos> and append <eos> to sequences in preprocessing step. """
    tokenize = lambda s: ['%', '&'] + ['BASEBALL']

    TEXT_TARGET = TargetField(batch_first=True,
                              sequential=True,
                              use_vocab=True,
                              lower=False,
                              init_token=TargetField.SYM_SOS,
                              eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize,
                              preprocessing=lambda x: x)  # fix_length=10
    TEXT = SourceField(batch_first=True,
                       sequential=True,
                       use_vocab=True,
                       lower=False,
                       tokenize=tokenize,
                       preprocessing=lambda x: x)  # fix_length=10
    LABEL = data.Field(batch_first=True,
                       sequential=False,
                       use_vocab=False,
                       tensor_type=torch.FloatTensor)

    fields = [('sent_0', TEXT), ('sent_1', TEXT), ('sent_x', TEXT),
              ('is_x_0', LABEL), ('sent_0_target', TEXT_TARGET)]
    ds_train = data.Dataset(TimeStyleDataset(1e3, 1, label_smoothing=True),
                            fields)
    ds_eval = data.Dataset(TimeStyleDataset(1e3, 2), fields)

    print('printing dataset directly, before tokenizing:')
    print('sent_0', ds_train[2].sent_0)  # not processed
    print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')
    # TEXT.build_vocab(ds, max_size=80000)
    TEXT_TARGET.build_vocab(ds_train, max_size=80000)
    TEXT.vocab = TEXT_TARGET.vocab  # same vocab, except for the added <sos>/<eos>

    print('vocab TEXT: len', len(TEXT.vocab), 'common',
          TEXT.vocab.freqs.most_common()[:50])
    print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab), 'uncommon',
          TEXT_TARGET.vocab.freqs.most_common()[-10::])
    print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
    print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
    print('vocab ', 'out-of-vocab', TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi['out-of-vocab'])

    device = None if torch.cuda.is_available() else -1
    # READ:  https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    sort_within_batch = True
    train_iter = iter(
        data.BucketIterator(dataset=ds_train,
                            device=device,
                            batch_size=32,
                            sort_within_batch=sort_within_batch,
                            sort_key=lambda x: len(x.sent_0)))
    eval_iter = iter(
        data.BucketIterator(dataset=ds_eval,
                            device=device,
                            batch_size=32,
                            sort_within_batch=sort_within_batch,
                            sort_key=lambda x: len(x.sent_0)))
    # performance note: the first next() takes ~3.5s; subsequent calls are fast (10000 in ~1s)

    for i in range(1):
        b = next(train_iter)
        # usage
        print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
        # print ('b.src is values+len tuple',b.src[0].shape,b.src[1].shape )
        print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0],
              revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ''))
        print('b_sent0', b.sent_0[0].shape, b.sent_0[1].shape, b.sent_0[0][0],
              revers_vocab(TEXT.vocab, b.sent_0[0][0], ''))
        print('b_sent1', b.sent_1[0].shape, b.sent_1[1].shape, b.sent_1[0][0],
              revers_vocab(TEXT.vocab, b.sent_1[0][0], ''))
        print('b_sentx', b.sent_x[0].shape, b.sent_x[1].shape, b.sent_x[0][0],
              revers_vocab(TEXT.vocab, b.sent_x[0][0], ''))
        print('b_y', b.is_x_0.shape, b.is_x_0)
        print(b.sent_0[1])

    return ds_train, ds_eval, train_iter, eval_iter
Example 14
def build_bible_datasets(verbose=False):
    """
    :return: bucket_iter_train, bucket_iter_valid
     To get an epoch iterator, call iter(bucket_iter_train) and then loop over next(...).
     The underlying dataset and fields are available via bucket_iter_train.dataset.fields.
    """
    def as_id_to_sentence(filename):
        d = {}
        with open(filename, 'r') as f:
            f.readline()
            for l in csv.reader(f.readlines(),
                                quotechar='"',
                                delimiter=',',
                                quoting=csv.QUOTE_ALL,
                                skipinitialspace=True):
                # id,b,c,v,t
                # 1001001,1,1,1,At the first God made the heaven and the earth.
                d[l[0]] = l[4]
        return d

    bbe = as_id_to_sentence('t_bbe.csv')
    wbt = as_id_to_sentence('t_wbt.csv')
    print('num of sentences', len(bbe), len(wbt))

    # merge into a list of (s1, s2) tuples
    bibles = []
    for sent_id, sent_wbt in wbt.items():
        if sent_id in bbe:
            sent_bbe = bbe[sent_id]
            bibles.append((sent_wbt, sent_bbe))
    if verbose:
        print(len(bibles), bibles[0])

    tokenize = 'revtok'  # alternative: lambda x: x.split(' ')
    TEXT_TARGET = TargetField(batch_first=True,
                              sequential=True,
                              use_vocab=True,
                              lower=True,
                              init_token=TargetField.SYM_SOS,
                              eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize)  # fix_length=30)
    TEXT = SourceField(batch_first=True,
                       sequential=True,
                       use_vocab=True,
                       lower=True,
                       tokenize=tokenize)  # , fix_length=20)
    LABEL = data.Field(batch_first=True,
                       sequential=False,
                       use_vocab=False,
                       tensor_type=torch.FloatTensor)

    bible_style_ds_trn = BibleStyleDS(
        [x for (i, x) in enumerate(bibles) if i % 10 != 9],
        TEXT,
        TEXT_TARGET,
        LABEL,
        label_smoothing=False)
    bible_style_ds_val = BibleStyleDS(
        [x for (i, x) in enumerate(bibles) if i % 10 == 9],
        TEXT,
        TEXT_TARGET,
        LABEL,
        label_smoothing=False)
    if verbose:
        for i in range(1):
            print("RAW SENTENCES", bible_style_ds_val[i])
        # print (type(bible_style_ds[i].sent_0),type(bible_style_ds[i].is_x_0),bible_style_ds[i])

    fields = [('sent_0', TEXT), ('sent_1', TEXT), ('sent_x', TEXT),
              ('is_x_0', LABEL), ('sent_0_target', TEXT_TARGET)]
    ds_train = data.Dataset(bible_style_ds_trn, fields)
    ds_val = data.Dataset(bible_style_ds_val, fields)

    # import pdb; pdb.set_trace()
    if verbose:
        print('printing dataset directly, before tokenizing:')
        print('sent_0', ds_train[2].sent_0)  # not processed
        print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')

    TEXT_TARGET.build_vocab(
        ds_train, vectors='fasttext.simple.300d', min_freq=20
    )  # alternatives: max_size=80000; vectors='fasttext.en.300d' or 'glove.twitter.27B.50d'
    TEXT.vocab = TEXT_TARGET.vocab  # same vocab, except for the added <sos>/<eos>
    print('total', len(TEXT.vocab), 'after ignoring non-frequent')
    if verbose:
        print('vocab TEXT: len', len(TEXT.vocab), 'common',
              TEXT.vocab.freqs.most_common()[:5])
        print('vocab TEXT: len', len(TEXT.vocab), 'uncommon',
              TEXT.vocab.freqs.most_common()[-5:])
        print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab),
              TEXT_TARGET.vocab.freqs.most_common()[:5])
        print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
              TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
        print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
              TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
        print('vocab ', 'out-of-vocab', TEXT_TARGET.vocab.stoi['out-of-vocab'])
        print('vocab ', 'i0',
              [(i, TEXT_TARGET.vocab.itos[i]) for i in range(6)])
    device = torch.device('cuda') if torch.cuda.is_available() else -1
    # READ:  https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    print('device is cuda or -1 for cpu:', device)

    bucket_iter_train = data.BucketIterator(dataset=ds_train,
                                            shuffle=True,
                                            device=device,
                                            batch_size=32,
                                            sort_within_batch=False,
                                            sort_key=lambda x: len(x.sent_0))
    bucket_iter_valid = data.BucketIterator(
        dataset=ds_val,
        shuffle=False,
        device=device,
        batch_size=32,
        sort_within_batch=False,  #sort_key=lambda x: len(x.sent_0)
    )

    if verbose:  #show few samples
        for i in range(1):
            # performance note: the first next() takes ~3.5s; subsequent calls are fast (10000 in ~1s)
            b = next(iter(bucket_iter_train))
            # usage
            print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
            # print ('b.src is values+len tuple',b.src[0].shape,b.src[1].shape )
            print('b.sent_0_target',
                  b.sent_0_target.shape)  # ,b.sent_0_target[0])
            print('b.sent_0_target',
                  revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ' '))
            print('b_sent0', b.sent_0[0].shape, b.sent_0[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_0[0][0], ' '))
            print('b_sent1', b.sent_1[0].shape, b.sent_1[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_1[0][0], ' '))
            print('b_sentx', b.sent_x[0].shape, b.sent_x[1].shape,
                  revers_vocab(TEXT.vocab, b.sent_x[0][0], ' '),
                  b.sent_x[0][0])
            print('b_y', b.is_x_0.shape, b.is_x_0[0])

    return bucket_iter_train, bucket_iter_valid
Example 15
def build_quora_dataset(verbose=False):
    # Create a dataset that is only used as an internal CSV reader
    SOURCE_INT = data.Field(batch_first=True,
                            sequential=False,
                            use_vocab=False)  # tensor_type =torch.IntTensor)
    ds = data.TabularDataset('train.csv',
                             format='csv',
                             skip_header=True,
                             fields=[('id', SOURCE_INT), ('qid1', SOURCE_INT),
                                     ('qid2', SOURCE_INT),
                                     ('question1', SOURCE_INT),
                                     ('question2', SOURCE_INT),
                                     ('is_duplicate', SOURCE_INT)])

    tokenize = 'revtok'  # alternative: lambda x: x.split(' ')
    TEXT_TARGET = TargetField(batch_first=True,
                              sequential=True,
                              use_vocab=True,
                              lower=True,
                              init_token=TargetField.SYM_SOS,
                              eos_token=TargetField.SYM_EOS,
                              tokenize=tokenize)  # fix_length=30)
    TEXT = SourceField(batch_first=True,
                       sequential=True,
                       use_vocab=True,
                       lower=True,
                       tokenize=tokenize)  # , fix_length=20)
    LABEL = data.Field(batch_first=True,
                       sequential=False,
                       use_vocab=False,
                       tensor_type=torch.FloatTensor)

    sem_style_ds = SemStyleDS(ds, TEXT, TEXT_TARGET, LABEL, max_id=1000 * 1000)
    for i in range(5):
        print(type(sem_style_ds[i].sent_0), type(sem_style_ds[i].is_x_0),
              sem_style_ds[i])

    ds_train = data.Dataset(sem_style_ds,
                            fields=[('sent_0', TEXT), ('sent_1', TEXT),
                                    ('sent_x', TEXT), ('is_x_0', LABEL),
                                    ('sent_0_target', TEXT_TARGET)])
    # import pdb; pdb.set_trace()
    print('printing dataset directly, before tokenizing:')
    print('sent_0', ds_train[2].sent_0)  # not processed
    print('is_x_0', ds_train[2].is_x_0)  # not processed

    print('\nbuilding vocab:')

    TEXT_TARGET.build_vocab(
        ds_train, vectors='fasttext.simple.300d'
    )  # alternatives: max_size=80000; vectors='fasttext.en.300d' or 'glove.twitter.27B.50d'
    TEXT.vocab = TEXT_TARGET.vocab  # same vocab, except for the added <sos>/<eos>

    print('vocab TEXT: len', len(TEXT.vocab), 'common',
          TEXT.vocab.freqs.most_common()[:10])
    print('vocab TEXT_TARGET:', len(TEXT_TARGET.vocab),
          TEXT_TARGET.vocab.freqs.most_common()[:10])
    print('vocab ', TEXT_TARGET.SYM_SOS, TEXT_TARGET.sos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_SOS])
    print('vocab ', TEXT_TARGET.SYM_EOS, TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi[TEXT_TARGET.SYM_EOS])
    print('vocab ', 'out-of-vocab', TEXT_TARGET.eos_id,
          TEXT_TARGET.vocab.stoi['out-of-vocab'])

    device = None if torch.cuda.is_available() else -1
    # READ:  https://github.com/mjc92/TorchTextTutorial/blob/master/01.%20Getting%20started.ipynb
    bucket_iter_train = data.BucketIterator(dataset=ds_train,
                                            shuffle=True,
                                            device=device,
                                            batch_size=32,
                                            sort_within_batch=False,
                                            sort_key=lambda x: len(x.sent_0))
    bucket_iter_valid = bucket_iter_train  # data.BucketIterator(dataset=ds_val, shuffle=False, device=device, batch_size=32,
    #    sort_within_batch=False, #sort_key=lambda x: len(x.sent_0)
    # )

    #bucket_iter_train = data.BucketIterator(dataset=ds_train, device=device, batch_size=32, sort_within_batch=False,
    #                                        sort_key=lambda x: len(x.sent_0))
    #print('$' * 40, 'change batch_size to 32')

    # performance note: the first next() takes ~3.5s; subsequent calls are fast (10000 in ~1s)

    if verbose:
        training_batch_generator = iter(bucket_iter_train)
        for i in range(5):
            b = next(training_batch_generator)
            # usage
            print('\nb.is_x_0', b.is_x_0[0], b.is_x_0.type())
            # print ('b.src is values+len tuple',b.src[0].shape,b.src[1].shape )
            print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0])
            print('b.sent_0_target', b.sent_0_target.shape, b.sent_0_target[0],
                  revers_vocab(TEXT_TARGET.vocab, b.sent_0_target[0], ' '))
            print('b_sent0', b.sent_0[0].shape,
                  b.sent_0[1].shape, b.sent_0[0][0],
                  revers_vocab(TEXT.vocab, b.sent_0[0][0], ' '))
            print('b_sent1', b.sent_1[0].shape,
                  b.sent_1[1].shape, b.sent_1[0][0],
                  revers_vocab(TEXT.vocab, b.sent_1[0][0], ' '))
            print('b_sentx', b.sent_x[0].shape,
                  b.sent_x[1].shape, b.sent_x[0][0],
                  revers_vocab(TEXT.vocab, b.sent_x[0][0], ' '))
            print('b_y', b.is_x_0.shape, b.is_x_0[0])

    # addons

    return bucket_iter_train, bucket_iter_valid
Example 16
logging.info(opt)

if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField(init_token='<sos>', eos_token='<eos>')
    tgt = TargetField(
        init_token='<sos>',
        eos_token='<eos>')  # init_token='<sos>', eos_token='<eos>'
    max_len = 300

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(
        path=opt.train_path,
        format='tsv',
        fields=[('src', src), ('tgt', tgt)],
    )

    dev = torchtext.data.TabularDataset(
Example 17
def train():
    src = SourceField(sequential=True,
                      tokenize=lambda x: [i for i in jieba.lcut(x)])
    tgt = TargetField(sequential=True,
                      tokenize=lambda x: [i for i in jieba.lcut(x)])
    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='csv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='csv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)

    src.build_vocab(train, max_size=50000)
    tgt.build_vocab(train, max_size=50000)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        hidden_size = 128
        bidirectional = True
        encoder = EncoderRNN(len(src.vocab),
                             max_len,
                             hidden_size,
                             bidirectional=bidirectional,
                             variable_lengths=True)
        decoder = DecoderRNN(len(tgt.vocab),
                             max_len,
                             hidden_size * 2 if bidirectional else hidden_size,
                             dropout_p=0.2,
                             use_attention=True,
                             bidirectional=bidirectional,
                             eos_id=tgt.eos_id,
                             sos_id=tgt.sos_id)
        seq2seq = Seq2seq(encoder, decoder)
        if torch.cuda.is_available():
            seq2seq.cuda()

        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)

        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and pass to the trainer.
        #
        # optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        # scheduler = StepLR(optimizer.optimizer, 1)
        # optimizer.set_scheduler(scheduler)

    # train
    t = SupervisedTrainer(loss=loss,
                          batch_size=32,
                          checkpoint_every=50,
                          print_every=10,
                          expt_dir=opt.expt_dir)

    seq2seq = t.train(seq2seq,
                      train,
                      num_epochs=6,
                      dev_data=dev,
                      optimizer=optimizer,
                      teacher_forcing_ratio=0.5,
                      resume=opt.resume)
    predictor = Predictor(seq2seq, input_vocab, output_vocab)
Example 18
def offline_training(opt, traget_file_path):

    # Prepare dataset with torchtext
    src = SourceField(tokenize=treebank_tokenizer)
    tgt = TargetField(tokenize=treebank_tokenizer)

    def sample_filter(sample):
        """ sample example for future purpose"""
        return True

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=sample_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=sample_filter)
    test = torchtext.data.TabularDataset(path=opt.dev_path,
                                         format='tsv',
                                         fields=[('src', src), ('tgt', tgt)],
                                         filter_pred=sample_filter)
    src.build_vocab(train, max_size=opt.src_vocab_size)
    tgt.build_vocab(train, max_size=opt.tgt_vocab_size)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # NOTE: If the source field name and the target field name
    # are different from 'src' and 'tgt' respectively, they have
    # to be set explicitly before any training or inference
    # seq2seq.src_field_name = 'src'
    # seq2seq.tgt_field_name = 'tgt'

    # Prepare loss
    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    if opt.loss == 'perplexity':
        loss = Perplexity(weight, pad)
    else:
        raise TypeError

    seq2seq = None
    optimizer = None
    if not opt.resume:
        # Initialize model
        encoder = EncoderRNN(vocab_size=len(src.vocab),
                             max_len=opt.max_length,
                             hidden_size=opt.hidden_size,
                             input_dropout_p=opt.input_dropout_p,
                             dropout_p=opt.dropout_p,
                             n_layers=opt.n_layers,
                             bidirectional=opt.bidirectional,
                             rnn_cell=opt.rnn_cell,
                             variable_lengths=True,
                             embedding=input_vocab.vectors
                             if opt.use_pre_trained_embedding else None,
                             update_embedding=opt.update_embedding)
        decoder = DecoderRNN(vocab_size=len(tgt.vocab),
                             max_len=opt.max_length,
                             hidden_size=opt.hidden_size *
                             2 if opt.bidirectional else opt.hidden_size,
                             sos_id=tgt.sos_id,
                             eos_id=tgt.eos_id,
                             n_layers=opt.n_layers,
                             rnn_cell=opt.rnn_cell,
                             bidirectional=opt.bidirectional,
                             input_dropout_p=opt.input_dropout_p,
                             dropout_p=opt.dropout_p,
                             use_attention=opt.use_attention)
        seq2seq = Seq2seq(encoder=encoder, decoder=decoder)
        if opt.gpu >= 0 and torch.cuda.is_available():
            seq2seq.cuda()

        for param in seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
    # train
    trainer = SupervisedTrainer(loss=loss,
                                batch_size=opt.batch_size,
                                checkpoint_every=opt.checkpoint_every,
                                print_every=opt.print_every,
                                expt_dir=opt.expt_dir)
    seq2seq = trainer.train(model=seq2seq,
                            data=train,
                            num_epochs=opt.epochs,
                            resume=opt.resume,
                            dev_data=dev,
                            optimizer=optimizer,
                            teacher_forcing_ratio=opt.teacher_forcing_rate)
Example 19
logging.info(opt)

if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    max_len = 5

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)
    src.build_vocab(train,
Example 20
logging.info(opt)

if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    max_len = 200

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
                                        filter_pred=len_filter)
Example 21
    parser.error(
        'load_checkpoint argument is required to resume training from checkpoint'
    )

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT,
                    level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if torch.cuda.is_available():
    print("Cuda device set to %i" % opt.cuda_device)
    torch.cuda.set_device(opt.cuda_device)

############################################################################
# Prepare dataset
src = SourceField()
tgt = TargetField()
max_len = opt.max_len


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


# generate training and testing data
train = torchtext.data.TabularDataset(path=opt.train,
                                      format='tsv',
                                      fields=[('src', src), ('tgt', tgt)],
                                      filter_pred=len_filter)

if opt.dev:
Example 22
logging.info(opt)

if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    max_len = 50

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)

    src.build_vocab(train, max_size=20000 - 2)
    tgt.build_vocab(train, max_size=20000 - 2)

    sos_id = tgt.vocab.stoi['<sos>']
Example 23
LOG_FORMAT = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s"
logging.basicConfig(format=LOG_FORMAT,
                    level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if opt.load_checkpoint is not None:
    checkpoint_path = os.path.join(EXPERIMENT_PATH,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    logging.info("loading checkpoint from {}".format(checkpoint_path))
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab

    src = SourceField(sequential=True, use_vocab=True)
    tgt = TargetField(sequential=True, use_vocab=True)
    src.rebuild_vocab(input_vocab)
    tgt.rebuild_vocab(output_vocab)

    weight = torch.ones(len(tgt.vocab))
    pad = tgt.vocab.stoi[tgt.pad_token]
    loss = BLEULoss(weight, pad, tgt)
else:
    # Prepare dataset
    src = SourceField(sequential=True, use_vocab=True)
    tgt = TargetField(sequential=True, use_vocab=True)
    max_len = 23

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format="tsv",
Example 24
opt = parser.parse_args()

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    src = SourceField()
    tgt = TargetField()
    max_len = 50
    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len
    train = torchtext.data.TabularDataset(
        path=opt.train_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)],
        filter_pred=len_filter
    )
    dev = torchtext.data.TabularDataset(
        path=opt.dev_path, format='tsv',
        fields=[('src', src), ('tgt', tgt)],
        filter_pred=len_filter
    )
    test = torchtext.data.TabularDataset(
Example 25
    # opt.load_checkpoint = str(opt.load_checkpoint)
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    feats_vocab = checkpoint.feats_vocab
    output_vocab = checkpoint.output_vocab
    # pdb.set_trace()
else:
    # Prepare dataset
    src = SourceField()
    feats = SourceField()
    tgt = TargetField()

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('feats', feats),
                                                  ('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('feats', feats), ('src', src),
                                                ('tgt', tgt)],
                                        filter_pred=len_filter)

    src.build_vocab(train, max_size=50000)
Example 26
        model = SpkSeq2seq(encoder, decoder)
        if torch.cuda.is_available():
            model.cuda()

        for param in model.parameters():
            param.data.uniform_(-0.08, 0.08)

    return model, input_vocab, output_vocab


init_log()

################  create dataset  #################

# Prepare dataset fields
src = SourceField()
spk = SpeakerField()
tgt = TargetField()

# load the dataset

# fields = [('0', spk), ('src', src), ('1', spk), ('src0', src), ('2', spk), ('src1', src), ('3', spk), ('tgt', src)]
#
# train, validation, dev = SpeakerDataset.splits(
#     num=args.num_sentence, format=args.format,
#     path=args.path,
#     fields=fields)

# load the dataset
if args.num_sentence > 1:
    fields = [('0', spk), ('src', src)]
Example 27
def run_training(opt, default_data_dir, num_epochs=100):
    if opt.load_checkpoint is not None:
        logging.info("loading checkpoint from {}".format(
            os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)))
        checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME, opt.load_checkpoint)
        checkpoint = Checkpoint.load(checkpoint_path)
        seq2seq = checkpoint.model
        input_vocab = checkpoint.input_vocab
        output_vocab = checkpoint.output_vocab
    else:

        # Prepare dataset
        src = SourceField()
        tgt = TargetField()
        max_len = 50

        data_file = os.path.join(default_data_dir, opt.train_path, 'data.txt')

        logging.info("Starting new Training session on %s", data_file)

        def len_filter(example):
            return (len(example.src) <= max_len) and (len(example.tgt) <= max_len) \
                   and (len(example.src) > 0) and (len(example.tgt) > 0)

        train = torchtext.data.TabularDataset(
            path=data_file, format='json',
            fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
            filter_pred=len_filter
        )

        dev = None
        if opt.no_dev is False:
            dev_data_file = os.path.join(default_data_dir, opt.train_path, 'dev-data.txt')
            dev = torchtext.data.TabularDataset(
                path=dev_data_file, format='json',
                fields={'src': ('src', src), 'tgt': ('tgt', tgt)},
                filter_pred=len_filter
            )

        src.build_vocab(train, max_size=50000)
        tgt.build_vocab(train, max_size=50000)
        input_vocab = src.vocab
        output_vocab = tgt.vocab

        # NOTE: If the source field name and the target field name
        # are different from 'src' and 'tgt' respectively, they have
        # to be set explicitly before any training or inference
        # seq2seq.src_field_name = 'src'
        # seq2seq.tgt_field_name = 'tgt'

        # Prepare loss
        weight = torch.ones(len(tgt.vocab))
        pad = tgt.vocab.stoi[tgt.pad_token]
        loss = Perplexity(weight, pad)
        if torch.cuda.is_available():
            logging.info("Yayyy We got CUDA!!!")
            loss.cuda()
        else:
            logging.info("No cuda available device found running on cpu")

        seq2seq = None
        optimizer = None
        if not opt.resume:
            hidden_size = 128
            decoder_hidden_size = hidden_size * 2
            logging.info("EncoderRNN Hidden Size: %s", hidden_size)
            logging.info("DecoderRNN Hidden Size: %s", decoder_hidden_size)
            bidirectional = True
            encoder = EncoderRNN(len(src.vocab), max_len, hidden_size,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 variable_lengths=True)
            decoder = DecoderRNN(len(tgt.vocab), max_len, decoder_hidden_size,
                                 dropout_p=0, use_attention=True,
                                 bidirectional=bidirectional,
                                 rnn_cell='lstm',
                                 eos_id=tgt.eos_id, sos_id=tgt.sos_id)

            seq2seq = Seq2seq(encoder, decoder)
            if torch.cuda.is_available():
                seq2seq.cuda()

            for param in seq2seq.parameters():
                param.data.uniform_(-0.08, 0.08)

        # Optimizer and learning rate scheduler can be customized by
        # explicitly constructing the objects and pass to the trainer.

        optimizer = Optimizer(torch.optim.Adam(seq2seq.parameters()), max_grad_norm=5)
        scheduler = StepLR(optimizer.optimizer, 1)
        optimizer.set_scheduler(scheduler)

        # train

        num_epochs = num_epochs
        batch_size = 32
        checkpoint_every = num_epochs / 10
        print_every = num_epochs / 100

        properties = dict(batch_size=batch_size,
                          checkpoint_every=checkpoint_every,
                          print_every=print_every, expt_dir=opt.expt_dir,
                          num_epochs=num_epochs,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        logging.info("Starting training with the following Properties %s", json.dumps(properties, indent=2))
        t = SupervisedTrainer(loss=loss, batch_size=batch_size,
                              checkpoint_every=checkpoint_every,
                              print_every=print_every, expt_dir=opt.expt_dir)

        seq2seq = t.train(seq2seq, train,
                          num_epochs=num_epochs, dev_data=dev,
                          optimizer=optimizer,
                          teacher_forcing_ratio=0.5,
                          resume=opt.resume)

        evaluator = Evaluator(loss=loss, batch_size=batch_size)

        if opt.no_dev is False:
            dev_loss, accuracy = evaluator.evaluate(seq2seq, dev)
            logging.info("Dev Loss: %s", dev_loss)
            logging.info("Accuracy: %s", dev_loss)

    beam_search = Seq2seq(seq2seq.encoder, TopKDecoder(seq2seq.decoder, 4))

    predictor = Predictor(beam_search, input_vocab, output_vocab)
    while True:
        try:
            seq_str = raw_input("Type in a source sequence:")
            seq = seq_str.strip().split()
            results = predictor.predict_n(seq, n=3)
            for i, res in enumerate(results):
                print('option %s: %s\n' % (i + 1, res))
        except KeyboardInterrupt:
            logging.info("Bye Bye")
            exit(0)
Example 28
params = {
    'n_layers': 2,
    'hidden_size': 512,
    'src_vocab_size': 15000,
    'tgt_vocab_size': 5000,
    'max_len': 128,
    'rnn_cell': 'lstm',
    'batch_size': opt.batch_size,
    'num_epochs': opt.epochs
}

logging.info(params)

# Prepare dataset
src = SourceField()
tgt = TargetField()
poison_field = torchtext.data.Field(sequential=False, use_vocab=False)
max_len = params['max_len']


def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len


def train_filter(example):
    return len_filter(example)


train, fields, src, tgt, poison_field, idx_field = load_data(
    opt.train_path, filter_func=train_filter)
Example 29
if opt.load_checkpoint is not None:
    logging.info("loading checkpoint from {}".format(
        os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                     opt.load_checkpoint)))
    checkpoint_path = os.path.join(opt.expt_dir,
                                   Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab
else:
    # Prepare dataset
    # SourceField requires that batch_first and include_lengths be true.
    src = SourceField(lower=opt.lower)
    # TargetField requires that batch_first be true as well as prepends <sos> and appends <eos> to sequences.
    tgt = TargetField()
    # Sequence's length cannot exceed max_len.
    max_len = 100
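    # Illustrative note (an assumption based on the comments above about TargetField):
    # a raw target line such as "a b c" is preprocessed into
    # ['<sos>', 'a', 'b', 'c', '<eos>'], so len(example.tgt) in the filter below
    # already includes the two added markers when compared against max_len.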

    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    train = torchtext.data.TabularDataset(path=opt.train_path,
                                          format='tsv',
                                          fields=[('src', src), ('tgt', tgt)],
                                          filter_pred=len_filter)
    dev = torchtext.data.TabularDataset(path=opt.dev_path,
                                        format='tsv',
                                        fields=[('src', src), ('tgt', tgt)],
Example 30
if torch.cuda.is_available():
        print("Cuda device set to %i" % opt.cuda_device)
        torch.cuda.set_device(opt.cuda_device)

#################################################################################
# load model

logging.info("loading checkpoint from {}".format(os.path.join(opt.checkpoint_path)))
checkpoint = Checkpoint.load(opt.checkpoint_path)
seq2seq = checkpoint.model
input_vocab = checkpoint.input_vocab
output_vocab = checkpoint.output_vocab

############################################################################
# Prepare dataset and loss
src = SourceField()
tgt = TargetField()
src.vocab = input_vocab
tgt.vocab = output_vocab
max_len = opt.max_len

def len_filter(example):
    return len(example.src) <= max_len and len(example.tgt) <= max_len

# generate test set
test = torchtext.data.TabularDataset(
    path=opt.test_data, format='tsv',
    fields=[('src', src), ('tgt', tgt)],
    filter_pred=len_filter
)