def _preload_data(self,
                  file_path,
                  label_file_path,
                  batch_size=1000,
                  max_len=None):
    seqs = seq_all(file_path)
    codes = seqs["codes"]
    label_index_df = pd.read_pickle(label_file_path)
    piece_labels = label_index_df["label"]
    indexes = label_index_df["index"]
    sample_ids = label_index_df["sample_id"]
    code_pieces = fetch_code_pieces(codes, sample_ids, indexes)

    ast_label_seqs = seq_from_code_ast(seqs)
    sub_codes, labels = ast_label_seqs["sub_code_pieces"], ast_label_seqs[
        self.hparams["snake_params"].label_type]

    # TODO: try a different sampling strategy for sub_codes.
    # code_pieces, piece_labels = sq.concatenate(sub_codes), sq.concatenate(labels)
    # code_pieces = sq.smap(utf8decode, code_pieces)
    tok_codes = sq.smap(tokenize_plus(self.tokenizer, max_len, True),
                        code_pieces)
    tok_piece_labels = sq.smap(label_tokenize(self.label_tokenizer),
                               piece_labels)
    return sq.collate([tok_codes, tok_piece_labels])
Example #2
def _preload_data(self, file_path, batch_size=1000, max_len=None):
    seqs = seq_all(file_path)
    codes, docs = seqs["codes"], seqs["docs"]
    tok_codes = sq.smap(tokenize_plus(self.tokenizer, max_len, True),
                        codes)
    tok_docs = sq.smap(tokenize_plus(self.tokenizer, max_len, True), docs)
    return sq.collate([tok_codes, tok_docs])
Example #3
def test_throughput():
    def proc(x):
        time.sleep(proc.delay)
        return x

    proc.delay = 0.01

    arr = list(range(100))
    monitored_arr = monitor_throughput(smap(proc, arr))
    x = list(monitored_arr)

    assert x == arr
    assert abs(monitored_arr.throughput() - 100) < 1

    monitored_arr.reset()

    with pytest.raises(RuntimeError):
        monitored_arr.read_delay()
    with pytest.raises(RuntimeError):
        monitored_arr.throughput()

    proc.delay = 0.02

    arr = list(range(100))
    monitored_arr = monitor_throughput(smap(proc, arr))
    x = [monitored_arr[i] for i in range(len(monitored_arr))]

    assert x == arr
    assert abs(monitored_arr.read_delay() - 0.02) < 0.002
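The tolerances in this test follow from the injected delays: with proc.delay = 0.01 s per item, throughput ≈ 1 / 0.01 = 100 items/s, and with proc.delay = 0.02 s the measured read_delay settles near 0.02 s, hence the 1 item/s and 0.002 s margins.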
Example #4
def seq_from_code_ast(_seq_dict):
    _code_bytes = _seq_dict["code_bytes"]
    # FIXME: for PHP the parse tree will almost always contain only 'program'
    # and 'text' nodes (even on the playground). Fix it by wrapping the code
    # bytes with <?php ... ?> and check whether any exceptions remain (label
    # counts show a different view). Java needs an extra class Test { ... }
    # wrapper, otherwise snippets will not parse correctly.
    _asts = _seq_dict["asts"]
    sub_code_pieces = sq.smap(_sub_code_pieces, _asts, _code_bytes)
    sub_code_indexes = sq.smap(_sub_code_indexes, _asts)
    # sub_asts = sq.smap(_sub_labels, _asts)
    sub_labels = sq.smap(_sub_labels, _asts)
    type_label = sq.smap(
        lambda lbs: [[x[1] for x in labels] for labels in lbs],
        sub_labels
    )
    combined_label = sq.smap(
        lambda lbs: [[f"{x[0]}-{x[1]}" for x in labels] for labels in lbs],
        sub_labels
    )

    _dict_all = locals()
    _dict_return = {k: v for k, v in _dict_all.items()
                    if not k.startswith("_")}
    # print(_dict_return.keys())
    return _dict_return
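A minimal sketch of the wrapper suggested by the FIXME above, assuming the parsers consume raw snippet bytes (the helper name and byte templates are illustrative, not part of this codebase):

def _wrap_code_bytes(language, code_bytes):
    # hypothetical helper: wrap snippet bytes so the parser sees a
    # complete compilation unit (see the FIXME in seq_from_code_ast)
    if language == "php":
        # without the tags, the PHP grammar lumps everything into 'text'
        return b"<?php\n" + code_bytes + b"\n?>"
    if language == "java":
        # bare members need an enclosing class Test { ... } to parse
        return b"class Test {\n" + code_bytes + b"\n}"
    return code_bytes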
Example #5
def test_prefetch_timing(method):
    def f1(x):
        sleep(.02 + 0.01 * (random.random() - .5))
        return x

    arr = list(range(100))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_cached=20, method=method, timeout=1)

    t1 = time()
    z = list(y)
    t2 = time()

    assert z == arr
    duration = t2 - t1
    print("test_prefetch_timing({}):1 {}".format(method, duration))
    assert duration < 1.3

    arr = list(range(200))
    y = smap(f1, arr)
    y = prefetch(y,
                 nworkers=2,
                 max_cached=20,
                 method=method,
                 timeout=1,
                 anticipate=lambda i: i + 2)

    t1 = time()
    z = [y[i] for i in range(0, len(y), 2)]
    t2 = time()

    assert z == arr[::2]
    duration = t2 - t1
    print("test_prefetch_timing({}):2 {}".format(method, duration))
    assert duration < 1.3
Example #6
def test_prefetch(method):
    def f1(x):
        sleep(0.005 * (1 + random.random()))
        return x

    if method == "process":
        start_hook = random.seed
    else:
        start_hook = None

    arr = list(range(300))
    y = smap(f1, arr)
    y = prefetch(y,
                 nworkers=4,
                 max_buffered=10,
                 method=method,
                 timeout=1,
                 start_hook=start_hook)

    # check if workers are properly restarted when asleep
    i = 0
    n_wakeups = 3
    for _ in range(500):
        if n_wakeups > 0 and random.random() < 0.005:
            sleep(1.1)  # will let worker go to sleep
            n_wakeups -= 1
        value = y[i]
        assert value == arr[i]
        if random.random() < 0.05:
            i = random.randrange(0, len(arr))
        else:
            i = (i + 1) % len(arr)

    # helps with coverage
    y.async_seq._finalize(y.async_seq)

    # overly large buffer
    arr = list(range(10))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=50, method=method, timeout=1)
    assert list(y) == arr

    # anticipate method
    arr = list(range(200))
    y = smap(f1, arr)
    y = prefetch(y,
                 nworkers=2,
                 max_buffered=20,
                 method=method,
                 timeout=1,
                 anticipate=lambda i: i + 2)

    z = [y[i] for i in range(0, len(y), 2)]

    assert z == arr[::2]
Example #7
def _preload_data(self, file_path, batch_size=1000, max_len=None):
    seqs = seq_all(file_path)
    codes = seqs["codes"]
    docs = seqs["docs"]
    # TODO: try a different sampling strategy for sub_codes.
    tok_both = sq.smap(tokenize_pair_plus(self.tokenizer, max_len, True),
                       docs, codes)
    # FIXME: <PAD> tokens are currently included in the random mask.
    tok_only = sq.smap(lambda x: x["input_ids"], tok_both)
    tok_piece_labels = sq.smap(
        random_mask(list(self.tokenizer.get_vocab().values())), tok_only)
    return sq.collate([tok_both, tok_piece_labels])
Example #8
def test_prefetch_errors(method):
    class CustomError(Exception):
        pass

    def f1(x):
        if x is None:
            raise CustomError()
        else:
            return x

    arr1 = [1, 2, 3, None]
    arr2 = smap(f1, arr1)
    y = prefetch(arr2, nworkers=2, max_cached=2, method=method)

    for i in range(3):
        assert y[i] == arr1[i]
    with pytest.raises(PrefetchException):
        a = y[3]
        del a

    def f2(x):
        if x is None:
            raise ValueError("blablabla")
        else:
            return x

    # helps with coverage
    y._finalize(y)

    arr2 = smap(f2, arr1)
    y = prefetch(arr2, nworkers=2, max_cached=2, method=method)

    for i in range(3):
        assert y[i] == arr1[i]
    try:
        a = y[3]
        del a
    except Exception as e:
        assert isinstance(e, PrefetchException)
        assert isinstance(e.__cause__, ValueError)
    else:
        assert False

    assert y[0] == 1
    assert y[1] == 2

    # helps with coverage
    y._finalize(y)
Example #9
def test_prefetch(method):
    def f1(x):
        sleep(0.005 * (1 + random.random()))
        return x

    if method == "process":
        start_hook = None
    else:
        start_hook = None

    arr = list(range(300))
    y = smap(f1, arr)
    y = prefetch(y,
                 nworkers=4,
                 max_cached=10,
                 method=method,
                 timeout=1,
                 start_hook=start_hook)
    # arr = arr[3:-1:2]
    # y = y[3:-1:2]

    i = 0
    n_wakeups = 3
    for _ in range(500):
        if n_wakeups > 0 and random.random() < 0.005:
            sleep(1.1)  # will let worker go to sleep
            n_wakeups -= 1
        assert y[i] == arr[i]
        if random.random() < 0.05:
            i = random.randrange(0, len(arr))
        else:
            i = (i + 1) % len(arr)

    # helps with coverage
    y._finalize(y)
Example #10
def test_debug():
    arr = list(range(100))

    def do(i, v):
        del i, v
        do.i += 1

    do.i = 0

    debugged_arr = debug(arr, do)

    assert list(debugged_arr) == arr
    assert do.i == 100
    assert [debugged_arr[i] for i in range(len(debugged_arr))] == arr
    assert do.i == 200

    do.i = 0
    debugged_arr = debug(arr, do, max_calls=3)

    assert list(debugged_arr) == arr
    assert do.i == 3

    def proc(x):
        time.sleep(0.01)
        return x

    do.i = 0
    debugged_arr = debug(smap(proc, arr), do, max_rate=10)

    assert list(debugged_arr) == arr
    assert do.i == 10
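The expected counts follow from the timing: 100 items at ~0.01 s each take about one second to iterate, so with max_rate=10 the hook fires roughly 10 times (10 calls/s × 1 s), matching the final assertion.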
Example #11
def test_cached():
    def f(x):
        return x

    cache_size = 3
    arr = [random.random() for _ in range(25)]
    z = add_cache(arr, cache_size)

    assert list(z) == arr
    assert list(z[10:]) == arr[10:]
    assert [z[i] for i in range(10)] == arr[:10]

    z[:10] = list(range(0, -10, -1))
    assert list(z[10:]) == arr[10:]
    assert list(z[:10]) == list(range(0, -10, -1))

    y = smap(f, arr)
    z = add_cache(y, cache_size)

    t1 = time()
    for i in range(len(arr)):
        assert z[i] == arr[i]
        for j in range(max(0, i - cache_size + 1), i + 1):
            assert z[j] == arr[j]
    t2 = time()

    duration = t2 - t1
    assert duration < .28
Example #12
def test_prefetch_errors(method, evaluation, picklable_err):
    class CustomError(Exception):
        pass

    def f1(x):
        if x is None:
            raise ValueError("blablabla") if picklable_err else CustomError()
        else:
            return x

    arr1 = [1, 2, 3, None]
    arr2 = smap(f1, arr1)
    y = prefetch(arr2, nworkers=2, max_buffered=2, method=method)

    seterr(evaluation)
    if (method == "process" and not picklable_err) or evaluation == "wrap":
        error_t = EvaluationError
    else:
        error_t = ValueError if picklable_err else CustomError

    for i in range(3):
        assert y[i] == arr1[i]
    with pytest.raises(error_t):
        a = y[3]
        del a
Example #13
    def make_sequence(self):
        """Build a sequence that looks like a dataloader when iterated over."""
        # shuffling
        if self.batch_sampler:
            batch_indices = list(self.batch_sampler)
            out = seqtools.smap(lambda bi: [self.dataset[i] for i in bi],
                                batch_indices)
        elif self.sampler:
            shuffle_indices = list(self.sampler)
            out = seqtools.gather(self.dataset, shuffle_indices)
        elif self.shuffle:
            shuffle_indices = np.random.permutation(len(self.dataset))
            out = seqtools.gather(self.dataset, shuffle_indices)
        else:
            out = self.dataset

        # batch
        if not self.batch_sampler and self.batch_size is not None:
            out = seqtools.batch(out,
                                 k=self.batch_size,
                                 drop_last=self.drop_last,
                                 collate_fn=self.collate_fn)
        elif self.batch_sampler:
            out = seqtools.smap(self.collate_fn, out)

        # prefetch
        if self.num_workers > 0:
            out = seqtools.prefetch(out,
                                    max_buffered=self.num_workers *
                                    self.prefetch_factor,
                                    nworkers=self.num_workers,
                                    method='process',
                                    start_hook=self.worker_init_fn,
                                    shm_size=self.shm_size)

        # convert into tensors
        out = seqtools.smap(into_tensors, out)

        # pin memory
        if self.pin_memory:
            out = seqtools.smap(pin_tensors_memory, out)
            out = seqtools.prefetch(out,
                                    nworkers=1,
                                    method='thread',
                                    max_buffered=1)

        return out
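The shuffle-then-batch path above in miniature, using only seqtools primitives (a self-contained sketch, not the full wrapper class):

import numpy as np
import seqtools

dataset = list(range(10))  # stand-in dataset
shuffle_indices = np.random.permutation(len(dataset))
shuffled = seqtools.gather(dataset, shuffle_indices)      # lazy shuffle
batches = seqtools.batch(shuffled, k=4, collate_fn=list)  # lazy batching
print(list(batches))  # e.g. [[7, 0, 3, 9], [5, 2, 8, 1], [6, 4]]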
Example #14
def transform_frames(frame_seq, t: Transformation):
    # shorthand notations
    duration = len(frame_seq)
    sx, sy = t.xscale, t.yscale
    rx, ry = t.ref2d
    rz = t.ref3d[2]

    # generate affine transformation matrix
    # triangles_src = np.array([[rx, rx + 1, rx], [ry, ry, ry + 1]], dtype=np.float32)
    triangles_src = np.array([[rx, ry], [rx + 10, ry], [rx, ry + 10]],
                             dtype=np.float32)
    triangles_dst = np.copy(triangles_src)
    x = triangles_dst[:, 0]
    y = triangles_dst[:, 1]

    z_corrections = 1 + t.zshift / (rz + .0001)
    x[...] = (x - rx) / z_corrections + rx
    y[...] = (y - ry) / z_corrections + ry

    x[...] = (x - rx) * sx + rx
    y[...] = (y - ry) * sy + ry

    # rotate about the reference point, evaluating both components from the
    # pre-rotation coordinates (updating x in place first would skew y)
    x0, y0 = x.copy(), y.copy()
    x[...] = rx + np.cos(t.tilt) * (x0 - rx) - np.sin(t.tilt) * (y0 - ry)
    y[...] = ry + np.sin(t.tilt) * (x0 - rx) + np.cos(t.tilt) * (y0 - ry)

    tmatrix = cv2.getAffineTransform(triangles_src, triangles_dst)

    # affine frame-wise transformations
    output = seqtools.smap(lambda f: cv2.warpAffine(f, tmatrix, (640, 480)),
                           frame_seq)

    # fliplr
    if t.fliplr:
        output = seqtools.smap(np.fliplr, output)

    # time scale
    if t.tscale != 1:
        output_duration = transform_durations(duration, t)
        indices = np.round(np.linspace(0, duration - 1,
                                       output_duration)).astype(int)
        output = seqtools.gather(output, indices)
        if t.tscale > 1:
            output = seqtools.add_cache(output, cache_size=1)

    return output
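For reference, the tilt step is a plain rotation of the reference triangle about (rx, ry), with θ = t.tilt:

    x' = rx + cos(θ)·(x − rx) − sin(θ)·(y − ry)
    y' = ry + sin(θ)·(x − rx) + cos(θ)·(y − ry)

Both right-hand sides depend on the original coordinates, which is why the code keeps pre-rotation copies instead of updating x in place first.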
Example #15
def test_smap_exceptions():
    def do(x):
        del x
        raise CustomException

    data = [random.random() for _ in range(100)]
    m = smap(do, data)
    with pytest.raises(CustomException):
        print(m[0])

    with pytest.raises(CustomException):
        next(iter(m))

    with pytest.raises(TypeError):
        smap(None, data)

    with pytest.raises(ValueError):
        smap(do)
Example #16
        def target():
            arr = np.random.rand(1000, 10)
            y = smap(f1, arr)
            y = prefetch(y,
                         method=method,
                         max_buffered=40,
                         nworkers=4,
                         start_hook=init_fn)

            for i in range(0, 1000):
                a = y[i]
Example #17
def test_prefetch_timings(prefetch_kwargs):
    def f1(x):
        sleep(0.005 * (1 + random.random()))
        return x

    start_hook = random.seed

    arr = np.random.rand(100, 10)
    y = smap(f1, arr)
    y = prefetch(y,
                 nworkers=4,
                 max_buffered=10,
                 start_hook=start_hook,
                 **prefetch_kwargs)
    y = [y_.copy()
         for y_ in y]  # copy needed to release buffers when shm_size>0
    assert_array_equal(np.stack(y), arr)

    # overly large buffer
    arr = np.random.rand(10, 10)
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=50, **prefetch_kwargs)
    y = [y_.copy() for y_ in y]
    assert_array_equal(np.stack(y), arr)

    # multiple restarts
    arr = np.random.rand(100, 10)
    y = smap(f1, arr)
    y = prefetch(y, nworkers=4, max_buffered=10, **prefetch_kwargs)
    for _ in range(10):
        n = np.random.randint(0, 99)
        for i in range(n):
            assert_array_equal(y[i], arr[i])

    # starvation
    arr = np.random.rand(100, 10)
    y = prefetch(arr, nworkers=2, max_buffered=10, **prefetch_kwargs)
    y[0]
    sleep(2)
    for i in range(1, 100):
        assert_array_equal(y[i], arr[i])
Example #18
def test_smap_exceptions(evaluation):
    def do(x):
        del x
        raise CustomException

    data = [random.random() for _ in range(100)]
    m = smap(do, data)

    seterr(evaluation)
    error_t = EvaluationError if evaluation == "wrap" else CustomException

    with pytest.raises(error_t):
        print(m[0])

    with pytest.raises(error_t):
        next(iter(m))

    with pytest.raises(TypeError):
        smap(None, data)

    with pytest.raises(ValueError):
        smap(do)
Example #19
def reload():
    global train_subset, val_subset, test_subset, \
        durations, gloss_seqs, pose2d_seqs, pose3d_seqs, frame_seqs

    with open(os.path.join(cachedir, "data.pkl"), 'rb') as f:
        durations, gloss_seqs, rec_mapping, transformations, \
            train_subset, val_subset, test_subset = pkl.load(f)

    segments = np.stack(
        [np.cumsum(durations) - durations,
         np.cumsum(durations)], axis=1)

    pose2d_seqs = seqtools.split(
        np.load(os.path.join(cachedir, "pose2d_seqs.npy"), mmap_mode='r'),
        segments)

    pose3d_seqs = seqtools.split(
        np.load(os.path.join(cachedir, "pose3d_seqs.npy"), mmap_mode='r'),
        segments)

    frame_seqs = seqtools.smap(lambda r: dataset.bgr_frames(r), rec_mapping)
    frame_seqs = seqtools.smap(transform_frames, frame_seqs, transformations)
Example #20
def test_smap_basics():
    n = 100
    data = [random.random() for _ in range(n)]

    def do(x):
        do.call_cnt += 1
        return x + 1

    do.call_cnt = 0

    # indexing
    result = smap(do, data)
    assert len(result) == len(data)
    assert do.call_cnt == 0
    assert list(result) == [x + 1 for x in data]
    assert do.call_cnt == n
    assert [result[i] for i in range(len(result))] == [x + 1 for x in data]
    assert list(result[:]) == [x + 1 for x in data]
Example #21
def test_prefetch_timing(method):
    def f1(x):
        sleep(.02 + 0.01 * (random.random() - .5))
        return x

    arr = list(range(420))
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_buffered=20, method=method, timeout=1)

    for i in range(20):
        y[i]  # consume first items to eliminate worker startup time
    t1 = time()
    for i in range(20, 420):
        y[i]
    t2 = time()

    duration = t2 - t1
    print("test_prefetch_timing({}) {:.2f}s".format(method, duration))

    assert duration < 4.5
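Back-of-envelope check for the bound: the 400 timed items sleep a mean of 0.02 s each and are overlapped across 2 workers, so the loop should take roughly 400 × 0.02 / 2 = 4.0 s, hence the 4.5 s limit.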
Example #22
def test_prefetch_throughput(prefetch_kwargs):  # pragma: no cover
    def f1(x):
        sleep(.02 + 0.01 * (random.random() - .5))
        return x

    arr = np.random.rand(420, 10)
    y = smap(f1, arr)
    y = prefetch(y, nworkers=2, max_buffered=40, **prefetch_kwargs)

    for i in range(20):
        y[i]  # consume first items to eliminate worker startup time

    t1 = time()
    for i in range(20, 420):
        y[i]
    t2 = time()

    duration = t2 - t1
    print("test_prefetch_timing: {:.2f}s".format(duration))

    assert duration < 4.5
Example #23
def test_cached_timing():
    def f(x):
        sleep(.01)
        return x

    cache_size = 3
    arr = [random.random() for _ in range(100)]

    y = smap(f, arr)
    z = add_cache(y, cache_size)

    t1 = time()
    for i in range(len(arr)):
        assert z[i] == arr[i]
        for j in range(max(0, i - cache_size + 1), i + 1):
            assert z[j] == arr[j]
    t2 = time()

    duration = t2 - t1
    print("test_cached_timing {:.2f}s".format(duration))

    assert duration < 1.2
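Back-of-envelope check for the bound: only the 100 first-time reads pay the 0.01 s delay (≈ 1.0 s total); the inner loop revisits the last cache_size items, which the cache serves without recomputation, hence the 1.2 s limit.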
Example #24
def test_prefetch_errors(error_mode, prefetch_kwargs, picklable_err):
    class CustomError(Exception):
        pass

    def f1(x):
        if x is None:
            raise ValueError("blablabla") if picklable_err else CustomError()
        else:
            return x

    arr1 = [np.random.rand(10), np.random.rand(10), np.random.rand(10), None]
    arr2 = smap(f1, arr1)
    y = prefetch(arr2, nworkers=2, max_buffered=4, **prefetch_kwargs)

    seterr(error_mode)
    if (prefetch_kwargs['method'] != "thread"
            and not picklable_err) or error_mode == "wrap":
        error_t = EvaluationError
    else:
        error_t = ValueError if picklable_err else CustomError

    for i in range(3):
        assert_array_equal(y[i], arr1[i])
    try:
        a = y[3]
        del a
    except Exception as e:
        assert type(e) == error_t
    else:
        assert False  # an error was expected

    if (prefetch_kwargs['method']
            == "process") and error_mode == "passthrough":

        class CustomObject:  # unpicklable object
            pass

        arr1 = [np.random.rand(10), CustomObject(), np.random.rand(10)]
        y = prefetch(arr1, nworkers=2, max_buffered=4, **prefetch_kwargs)
        with pytest.raises(ValueError):
            y[1]
Example #25
def seq_all(_input_path):
    _sample_df = pd.read_pickle(_input_path)
    # index = _sample_df.index
    codes = _sample_df["code"]
    docs = _sample_df["docstring"]
    code_bytes = sq.smap(utf8encode, codes)
    languages = _sample_df["language"]
    parsers = sq.smap(get_parser, languages)
    asts = sq.smap(Parser.parse, parsers, code_bytes)
    code_tokens = _sample_df["code_tokens"]
    doc_tokens = _sample_df["docstring_tokens"]
    code_split_identifiers = sq.smap(indentifier_split, code_tokens)
    code_tokens_with_identifier_split = sq.smap(pd.Series.explode,
                                                code_split_identifiers)
    doc_split_identifiers = sq.smap(indentifier_split, doc_tokens)
    doc_tokens_with_identifier_split = sq.smap(pd.Series.explode,
                                               doc_split_identifiers)
    _dict_all = locals()
    _dict_return = {
        k: v
        for k, v in _dict_all.items() if not k.startswith("_")
    }
    # print(_dict_return.keys())
    return _dict_return
Example #26
import pandas as pd
import dataset_seq as ds
import ast_label_pretrain as ap
import seqtools as sq
from itertools import chain

from utils import fetch_snakemake_from_latest_run

try:
    snakemake
except NameError:
    snakemake = fetch_snakemake_from_latest_run(__file__)

seqs_all = ds.seq_all(snakemake.input[0])
seqs_labels = ap.seq_from_code_ast(seqs_all)
sub_code_indexes = seqs_labels["sub_code_indexes"]
type_label = seqs_labels[snakemake.params.label_type]
sample_ids = sq.smap(
    lambda samp_list, samp_index: [samp_index] * len(samp_list),
    sub_code_indexes, range(0, len(sub_code_indexes)))
df = pd.DataFrame({
    "index": chain.from_iterable(sub_code_indexes),
    "sample_id": chain.from_iterable(sample_ids),
    "label": chain.from_iterable(type_label),
})
df = df[df["index"].apply(lambda x: x[0] != x[1])].reset_index(drop=True)
df.to_pickle(snakemake.output[0])
Example #27
def fetch_code_pieces(codes, sample_ids, indexes):
    piece_full_code = sq.smap(lambda x: codes[x], sample_ids)
    code_pieces = sq.smap(_fetch_sub_code, indexes, piece_full_code)
    return code_pieces
Example #28
from __future__ import print_function
import time
import seqtools

files = [
    'file1', 'file2', 'file3', 'file4', 'file5', 'file6', 'file7', 'file8',
    'file9', 'file10'
]


def load(some_file):
    time.sleep(.1)
    return list(range(10) if some_file == 'file10' else range(200))


loaded_files = seqtools.smap(load, files)
loaded_files = seqtools.add_cache(loaded_files, 2)
all_samples = seqtools.unbatch(loaded_files, 200, 10)


def preprocess(x):
    # time.clock() was removed in Python 3.8; use a monotonic counter
    t = time.perf_counter()
    while time.perf_counter() - t < 0.005:
        pass  # busy waiting
    return x


preprocessed_samples = seqtools.smap(preprocess, all_samples)
minibatches = seqtools.batch(preprocessed_samples, 64, collate_fn=list)

t1 = time.time()
Example #29
def prepare():
    global train_subset, val_subset, test_subset, \
        durations, gloss_seqs, pose2d_seqs, pose3d_seqs

    # Create temporary directory
    if not os.path.exists(cachedir):
        os.mkdir(cachedir)

    # Load data
    train_subset, val_subset, test_subset = dataset.default_splits()
    pose2d_seqs = [dataset.positions(i) for i in range(len(dataset))]
    pose3d_seqs = [dataset.positions_3d(i) for i in range(len(dataset))]

    # Eliminate strange gloss annotations
    gloss_seqs_train = [dataset.glosses(r) for r in train_subset]
    rejected = set()
    for r, gseq in zip(train_subset, gloss_seqs_train):
        for i in range(len(gseq) - 1):
            if gseq[i + 1, 1] - gseq[i, 2] < 0:
                rejected.add(r)
    train_subset = np.setdiff1d(train_subset, rejected)
    if len(rejected) > 0:
        logging.warning(
            "Eliminated sequences with invalid glosses: {}".format(rejected))

    # Interpolate missing poses and eliminate deteriorated training sequences
    invalid_masks = seqtools.smap(detect_invalid_pts, pose2d_seqs)
    pose2d_seqs = seqtools.smap(interpolate_positions, pose2d_seqs,
                                invalid_masks)
    pose3d_seqs = seqtools.smap(interpolate_positions, pose3d_seqs,
                                invalid_masks)

    rejected = np.where(
        [np.mean(im[:, important_joints]) > .15 for im in invalid_masks])[0]
    train_subset = np.setdiff1d(train_subset, rejected)
    if len(rejected) > 0:
        logging.warning(
            "eliminated {} sequences with missing positions".format(
                len(rejected)))

    # Default preprocessing
    ref2d = seqtools.add_cache(seqtools.smap(get_ref_pts, pose2d_seqs),
                               cache_size=1)
    ref3d = seqtools.add_cache(seqtools.smap(get_ref_pts, pose3d_seqs),
                               cache_size=1)

    transformations = np.rec.array(
        [(r2, r3, False, 0, tgt_dist - r3[2], 1., 1., 1., 1.)
         for r2, r3 in zip(ref2d, ref3d)],
        dtype=TransformationType)

    # Precompute transformations for augmentation of the training set
    original_train_subset = train_subset

    rec_mapping = np.arange(len(dataset))
    for _ in range(5 - 1):
        offset = len(rec_mapping)
        new_subset = np.arange(offset, offset + len(original_train_subset))

        newt = np.repeat(transformations[0], len(new_subset),
                         axis=0).view(np.recarray)
        newt.fliplr = uniform(size=len(newt)) < 0.15
        newt.tilt += uniform(-7, 7, size=len(newt)) * np.pi / 180
        # scales are initialised to 1, so draw them directly in [.85, 1.15]
        # (+= would shift them to around 2)
        newt.xscale = uniform(.85, 1.15, size=len(newt))
        newt.yscale = uniform(.85, 1.15, size=len(newt))
        newt.zscale = uniform(.85, 1.15, size=len(newt))
        newt.tscale = uniform(.85, 1.15, size=len(newt))

        rec_mapping = np.concatenate([rec_mapping, original_train_subset])
        transformations = np.concatenate([transformations,
                                          newt]).view(np.recarray)
        train_subset = np.concatenate([train_subset, new_subset])

    # Apply transformations (if they are cheap to compute)
    durations = np.array([
        transform_durations(dataset.durations(r), t)
        for r, t in zip(rec_mapping, transformations)
    ])

    gloss_seqs = [
        transform_glosses(dataset.glosses(r), dataset.durations(r), t)
        for r, t in zip(rec_mapping, transformations)
    ]

    pose2d_seqs = seqtools.gather(pose2d_seqs, rec_mapping)
    pose2d_seqs = seqtools.smap(
        partial(transform_pose2d, flip_mapping=flip_mapping, frame_width=640),
        pose2d_seqs, transformations)

    pose3d_seqs = seqtools.gather(pose3d_seqs, rec_mapping)
    pose3d_seqs = seqtools.smap(
        partial(transform_pose3d, flip_mapping=flip_mapping), pose3d_seqs,
        transformations)

    # Export
    np.save(os.path.join(cachedir, "pose2d_seqs.npy"),
            seqtools.concatenate(pose2d_seqs))
    np.save(os.path.join(cachedir, "pose3d_seqs.npy"),
            seqtools.concatenate(pose3d_seqs))

    with open(os.path.join(cachedir, "data.pkl"), 'wb') as f:
        pkl.dump((durations, gloss_seqs, rec_mapping, transformations,
                  train_subset, val_subset, test_subset), f)
Example #30
        return query_embeddings, code_embeddings, losses, reciprocal_rank, mrr


# train, test and evaluation.

ds_train = seq_all(input.train_data)

code_tokens = ds_train["code_tokens_with_identifier_split"]
doc_tokens = ds_train["doc_tokens_with_identifier_split"]

code_token_counter = count_all(code_tokens)
doc_token_counter = count_all(doc_tokens)

import seqtools as sq

code_tokenized = sq.smap(tokenize(code_token_counter), code_tokens)
doc_tokenized = sq.smap(tokenize(doc_token_counter), doc_tokens)

code_pad = sq.smap(pad_to_1d(code_token_counter, 200), code_tokenized)
doc_pad = sq.smap(pad_to_1d(doc_token_counter, 200), doc_tokenized)

# batch and pad dataset.

model = CodeQuerySoftmaxBertModel(code_token_counter, doc_token_counter)

opt = torch.optim.Adam(model.parameters())

from tqdm import trange, tqdm

train_data = sq.collate([doc_pad, code_pad])