コード例 #1
0
def encode_all_types(df_ret: pd.DataFrame, df_params: pd.DataFrame,
                     df_vars: pd.DataFrame, output_dir: str):
    """
    Label-encodes return, parameter and variable types with one shared encoder.

    A single ``LabelEncoder`` is fitted on the union of the ``return_type``,
    ``arg_type`` and ``var_type`` columns. The encodings are stored in the new
    columns ``return_type_enc_all``, ``arg_type_enc_all`` and
    ``var_type_enc_all``; ``df_vars`` is modified in place. The type
    vocabulary, sorted by descending frequency, is written to
    ``<output_dir>/_most_frequent_all_types.csv`` with the columns
    ``enc``, ``type`` and ``count``.

    :param df_ret: return-type datapoints (column ``return_type``)
    :param df_params: parameter datapoints (column ``arg_type``)
    :param df_vars: variable datapoints (column ``var_type``); modified in place
    :param output_dir: directory where the type-frequency CSV is written
    :return: the updated ``df_ret``, ``df_params`` and the fitted encoder
    """
    all_types = np.concatenate(
        (df_ret['return_type'].values, df_params['arg_type'].values,
         df_vars['var_type'].values),
        axis=0)
    le_all = LabelEncoder()
    le_all.fit(all_types)
    df_ret['return_type_enc_all'] = le_all.transform(
        df_ret['return_type'].values)
    df_params['arg_type_enc_all'] = le_all.transform(
        df_params['arg_type'].values)
    df_vars['var_type_enc_all'] = le_all.transform(df_vars['var_type'].values)

    unq_types, count_unq_types = np.unique(all_types, return_counts=True)
    # Order the vocabulary by descending frequency and encode the *re-ordered*
    # types, so each row's 'enc' is the encoding of that same row's 'type'.
    by_freq = np.argsort(count_unq_types)[::-1]
    sorted_types = unq_types[by_freq]
    sorted_counts = count_unq_types[by_freq]
    pd.DataFrame({
        'enc': le_all.transform(sorted_types),
        'type': sorted_types,
        'count': sorted_counts,
    }).to_csv(os.path.join(output_dir, "_most_frequent_all_types.csv"),
              index=False)

    logger.info(f"Total no. of extracted types: {len(all_types):,}")
    logger.info(f"Total no. of unique types: {len(unq_types):,}")

    return df_ret, df_params, le_all
コード例 #2
0
def load_model_params(params_file_path: str = None) -> dict:
    """
    Returns the hyper-parameters of the Type4Py model.

    :param params_file_path: optional path to a user-provided parameters file;
        when omitted, the ``model_params.json`` bundled with this package is
        used instead
    :return: the parameters as loaded by ``load_json``
    """
    if params_file_path is None:
        bundled_params_file = pkg_resources.resource_filename(
            __name__, 'model_params.json')
        return load_json(bundled_params_file)

    logger.info(
        "Loading user-provided hyper-parameters for the Type4Py model...")
    return load_json(params_file_path)
コード例 #3
0
 def light_assess(self, fpath):
     """
     Runs a lightweight type-check assessment of a single file.

     :param fpath: path of the file to assess
     :return: True when ``_check_tc_outcome`` accepts the type checker's
         result; False when a ``CustomError`` or ``CustomWarning`` is raised
         (logged at error or warning level respectively)
     """
     logger.info(f"Light assessing {fpath}.")
     try:
         # NOTE: the self._check_basics(fpath) pre-check is currently disabled.
         ret_code, output_lines = self._type_check(fpath)
         self._check_tc_outcome(ret_code, output_lines)
     except CustomError as err:
         logger.error(str(err))
         return False
     except CustomWarning as warn:
         logger.warning(str(warn))
         return False
     else:
         logger.info("Passed the light assessment.")
         return True
コード例 #4
0
ファイル: data_loaders.py プロジェクト: saltudelft/type4py
def load_test_data_per_model(data_loading_funcs: dict, output_path: str,
                             no_batches: int, drop_last_batch: bool = False,
                             num_workers: int = 12):
    """
    Loads the test set for the model variant named in ``data_loading_funcs``.

    ``data_loading_funcs['test'](output_path)`` must return the variant's
    feature arrays followed by ``t_idx``:

    * ``'woi'`` (without identifiers): X_tok, X_type, t_idx
    * ``'woc'`` (without code tokens): X_id, X_type, t_idx
    * ``'wov'`` (without visible type hints): X_id, X_tok, t_idx
    * any other name (complete model): X_id, X_tok, X_type, t_idx

    The third value returned by ``data_loading_funcs['labels'](output_path)``
    is used as the test labels.

    :param data_loading_funcs: dict with the keys ``'name'``, ``'test'`` and
        ``'labels'``
    :param output_path: directory passed to the loading functions
    :param no_batches: batch size of the returned DataLoader
    :param drop_last_batch: whether the DataLoader drops an incomplete last batch
    :param num_workers: number of DataLoader worker processes
    :return: ``(test DataLoader, t_idx)``
    :raises ValueError: if the test loader returns an unexpected number of values
    """
    load_data_t = time()
    name = data_loading_funcs['name']
    # Ablated variants ('woi', 'woc', 'wov') drop one of the three feature groups.
    no_features = 2 if name in ('woi', 'woc', 'wov') else 3

    test_data = tuple(data_loading_funcs['test'](output_path))
    if len(test_data) != no_features + 1:
        raise ValueError(
            f"Expected {no_features + 1} values from the '{name}' test loader, "
            f"got {len(test_data)}")
    *X_test, t_idx = test_data
    _, _, Y_all_test = data_loading_funcs['labels'](output_path)

    triplet_data_test = TripletDataset(*X_test, labels=Y_all_test,
                                       dataset_name=name, train_mode=False)

    logger.info(f"Loaded the test set of the {name} dataset in {(time()-load_data_t)/60:.2f} min")

    return DataLoader(triplet_data_test, batch_size=no_batches,
                      num_workers=num_workers,
                      drop_last=drop_last_batch), t_idx
コード例 #5
0
def preprocess_parametric_types(
        df_param: pd.DataFrame, df_ret: pd.DataFrame, df_vars: pd.DataFrame
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Reduces the depth of parametric types.

    Every type string matching ``.+\\[.+\\]`` is parsed with libcst and
    rewritten by ``ParametricTypeDepthReducer`` with
    ``max_annot_depth=MAX_PARAM_TYPE_DEPTH``. If strict parsing raises
    ``ParserSyntaxError``, ``lenient_parse_module`` is tried instead; types
    that fail both become ``None``. Other types are returned unchanged.

    :param df_param: parameters; column ``arg_type`` is rewritten in place
    :param df_ret: returns; column ``return_type`` is rewritten in place
    :param df_vars: variables; column ``var_type`` is rewritten in place
    :return: the three DataFrames, in the same order
    """
    from libcst import parse_module, ParserSyntaxError

    # Number of types that only parsed with the lenient parser. Kept in the
    # closure (not a module-level global) so separate calls share no state.
    lenient_parse_count = 0

    def reduce_depth_param_type(t: str):
        """Returns the depth-reduced type string, ``t`` itself, or None."""
        nonlocal lenient_parse_count
        if not regex.match(r'.+\[.+\]', t):
            return t

        try:
            module = parse_module(t)
            module = module.visit(
                ParametricTypeDepthReducer(
                    max_annot_depth=MAX_PARAM_TYPE_DEPTH))
            return module.code
        except ParserSyntaxError:
            pass

        # Fall back to the lenient parser on the original string.
        try:
            module = lenient_parse_module(t)
            module = module.visit(
                ParametricTypeDepthReducer(
                    max_annot_depth=MAX_PARAM_TYPE_DEPTH))
        except ParserSyntaxError:
            return None
        lenient_parse_count += 1
        return module.code

    df_param['arg_type'] = df_param['arg_type'].progress_apply(
        reduce_depth_param_type)
    df_ret['return_type'] = df_ret['return_type'].progress_apply(
        reduce_depth_param_type)
    df_vars['var_type'] = df_vars['var_type'].progress_apply(
        reduce_depth_param_type)
    logger.info(f"Sucssesfull lenient parsing {lenient_parse_count}")

    return df_param, df_ret, df_vars
コード例 #6
0
def type4py_to_onnx(args):
    """
    Exports the pre-trained complete Type4Py model to ONNX and verifies it.

    The model is loaded from ``<args.o>/type4py_complete_model.pt``, traced
    with random inputs whose batch size is ``BATCH_SIZE``, and exported to
    ``<args.o>/type4py_complete_model.onnx`` with a dynamic batch dimension.
    The export is then checked with ``onnx.checker``. Finally it is run with
    ONNX Runtime and compared against the PyTorch output
    (``rtol=1e-3``, ``atol=1e-5``).

    :param args: parsed CLI arguments; ``args.o`` is the model directory
    :raises AssertionError: if ONNX Runtime's output differs from PyTorch's
    """
    pt_model_path = join(args.o, "type4py_complete_model.pt")
    onnx_model_path = join(args.o, "type4py_complete_model.onnx")

    type4py_model = torch.load(pt_model_path).model
    type4py_model.eval()
    logger.info("Loaded the pre-trained Type4Py model")

    # Random tracing inputs: x_id (BATCH_SIZE, 31, 100), x_tok
    # (BATCH_SIZE, 21, 100) and a 0/1-valued float matrix x_avl (BATCH_SIZE, 1024).
    x_id = torch.randn(BATCH_SIZE, 31, 100, requires_grad=True).to(DEVICE)
    x_tok = torch.randn(BATCH_SIZE, 21, 100, requires_grad=True).to(DEVICE)
    x_avl = torch.randint(low=0, high=2, size=(BATCH_SIZE, 1024),
                          dtype=torch.float32,
                          requires_grad=True).to(DEVICE)

    t_out = type4py_model(x_id, x_tok, x_avl)

    input_names = ['id', 'tok', 'avl']
    output_names = ['output']
    # dynamic_axes keys must be the declared input/output names; mark the
    # batch dimension (axis 0) of every input and of the output as dynamic.
    dynamic_axes = {name: {0: 'batch_size'}
                    for name in input_names + output_names}
    torch.onnx.export(type4py_model, (x_id, x_tok, x_avl), onnx_model_path,
                      export_params=True, do_constant_folding=True,
                      input_names=input_names, output_names=output_names,
                      dynamic_axes=dynamic_axes)
    logger.info("Exported the pre-trained Type4Py model to an ONNX model")

    type4py_onnx_m = onnx.load(onnx_model_path)
    onnx.checker.check_model(type4py_onnx_m)

    # Run the exported model on the same inputs and compare with PyTorch.
    ort_session = onnxruntime.InferenceSession(onnx_model_path)
    ort_input_specs = ort_session.get_inputs()
    ort_inputs = {ort_input_specs[0].name: to_numpy(x_id),
                  ort_input_specs[1].name: to_numpy(x_tok),
                  ort_input_specs[2].name: to_numpy(x_avl)}
    ort_outs = ort_session.run(None, ort_inputs)

    np.testing.assert_allclose(to_numpy(t_out), ort_outs[0], rtol=1e-03, atol=1e-05)
    logger.info("The exported Type4Py model has been tested with ONNXRuntime, and the result looks good!")
コード例 #7
0
    def train_model(self, corpus_iterator: TokenIterator,
                    model_path_name: str) -> None:
        """
        Trains a Word2Vec model on a corpus and saves it to a file.

        The model uses ``min_count=5``, ``window=5``,
        ``vector_size=W2V_VEC_LENGTH`` and one worker per CPU core. It is
        trained for 20 epochs after its vocabulary is built. The corpus is
        iterated by both steps, so it must be re-iterable.

        :param corpus_iterator: object that can iterate over the corpus
        :param model_path_name: file path the trained model is saved to
        """
        model = Word2Vec(min_count=5,
                         window=5,
                         vector_size=W2V_VEC_LENGTH,
                         workers=multiprocessing.cpu_count())

        vocab_start = time()
        model.build_vocab(corpus_iterable=corpus_iterator)
        vocab_mins = round((time() - vocab_start) / 60, 2)
        logger.info(f'Built W2V vocab in {vocab_mins} mins')
        logger.info(f"W2V model's vocab size: {len(model.wv):,}")

        train_start = time()
        model.train(
            corpus_iterable=corpus_iterator,
            total_examples=model.corpus_count,
            epochs=20,
            report_delay=1,
        )
        train_mins = round((time() - train_start) / 60, 2)
        logger.info(f'Built W2V model in {train_mins} mins')

        model.save(model_path_name)
コード例 #8
0
def evaluate(output_path: str,
             data_name: str,
             tasks: set,
             top_n: int = 10,
             mrr_all=False):
    """
    Evaluates the saved test predictions of a Type4Py model.

    Reads ``type4py_<data_name>_test_predictions.json``, the shared label
    encoder ``label_encoder_all.pkl`` and the encoded common types
    ``complete_common_types.pkl`` from ``output_path``. The common types are
    decoded back to type strings and everything is passed to
    ``eval_pred_dsl``.

    :param output_path: directory holding the predictions and pickle files
    :param data_name: model/dataset name used in the predictions file name
    :param tasks: set of prediction tasks to evaluate
    :param top_n: number of top predictions considered per sample
    :param mrr_all: whether to also log per-category MRR figures
    """
    logger.info(
        f"Evaluating the Type4Py {data_name} model for {tasks} prediction task"
    )
    logger.info(
        f"*************************************************************************"
    )
    # Load the predictions, the label encoder and the common types.
    # NOTE: pickle.load can execute arbitrary code; these files must be this
    # pipeline's own outputs, never untrusted input.
    test_pred = load_json(
        join(output_path, f'type4py_{data_name}_test_predictions.json'))
    with open(join(output_path, "label_encoder_all.pkl"), 'rb') as f:
        le_all = pickle.load(f)
    with open(join(output_path, "complete_common_types.pkl"), 'rb') as f:
        common_types = pickle.load(f)
    common_types = set(le_all.inverse_transform(list(common_types)))

    eval_pred_dsl(test_pred, common_types, tasks, top_n=top_n, mrr_all=mrr_all)
コード例 #9
0
def filter_var_wo_type(df_vars: pd.DataFrame) -> pd.DataFrame:
    """
    Drops variables whose type is missing.

    :param df_vars: variables DataFrame with a ``var_type`` column
    :return: only the rows whose ``var_type`` is not null
    """
    no_before = len(df_vars)
    logger.info(f"Variables before dropping: {no_before:,}")

    has_type = df_vars['var_type'].notnull()
    df_vars = df_vars[has_type]
    no_after = len(df_vars)

    logger.info(f"Variables after dropping dropping: {no_after:,}")
    logger.info(
        f"Filtered out {no_before - no_after:,} variables w/o a type.")
    return df_vars
コード例 #10
0
def filter_functions(
        df: pd.DataFrame,
        funcs=('str', 'unicode', 'repr', 'len', 'doc',
               'sizeof')) -> pd.DataFrame:
    """
    Filters out functions whose name is in ``funcs``.

    Judging by the log messages, the default names are dunder methods
    (``__str__``, ``__repr__``, ...). They are presumably stored without
    underscores in the ``name`` column (TODO: confirm).

    :param df: functions DataFrame with a ``name`` column
    :param funcs: names to drop (any list-like accepted by ``Series.isin``).
        The default is an immutable tuple, so it cannot be mutated across calls.
    :return: filtered dataframe
    """
    df_len = len(df)
    logger.info(f"Functions before dropping on __*__ methods {len(df):,}")
    df = df[~df['name'].isin(funcs)]
    logger.info(f"Functions after dropping on __*__ methods {len(df):,}")
    logger.info(f"Filtered out {df_len - len(df):,} functions.")

    return df
コード例 #11
0
def train_loop_dsl(model: TripletModel, criterion, optimizer,
                   train_data_loader: DataLoader,
                   valid_data_loader: DataLoader, learning_rate: float,
                   epochs: int, ubiquitous_types: set, common_types: set,
                   model_path: str):
    """
    Trains a triplet model for ``epochs`` epochs, optionally validating.

    Each training batch yields ``(anchor, positive_ex, negative_ex)``. Only
    the first item of each element is fed to the model; the second item is
    ignored (presumably the labels, TODO confirm). The three embeddings are
    scored with ``criterion`` and ``optimizer`` takes one step per batch. The
    summed training loss is logged after each epoch. When ``valid_data_loader``
    is given, the validation loss is computed and logged every 5 epochs.

    :param model: triplet model to train (updated in place)
    :param criterion: loss called as ``criterion(anchor, positive, negative)``
        on the three embeddings
    :param optimizer: optimizer over ``model``'s parameters
    :param train_data_loader: training batches
    :param valid_data_loader: validation batches, or None to skip validation
    :param learning_rate: unused here; kept for API compatibility
    :param epochs: number of training epochs
    :param ubiquitous_types: set of ubiquitous types, forwarded to
        ``compute_validation_loss_dsl``
    :param common_types: set of common types, forwarded to
        ``compute_validation_loss_dsl``
    :param model_path: unused here; kept for API compatibility
    """
    from type4py.predict import predict_type_embed

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0

        for anchor, positive_ex, negative_ex in tqdm(
                train_data_loader,
                total=len(train_data_loader),
                desc=f"Epoch {epoch}"):
            anchor = anchor[0]
            positive_ex = positive_ex[0]
            negative_ex = negative_ex[0]

            optimizer.zero_grad()
            anchor_embed, positive_ex_embed, negative_ex_embed = model(
                anchor, positive_ex, negative_ex)
            loss = criterion(anchor_embed, positive_ex_embed,
                             negative_ex_embed)

            # Backward and optimize
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        logger.info(f"epoch: {epoch} train loss: {total_loss}")

        if valid_data_loader is not None and epoch % 5 == 0:
            logger.info("Evaluating on validation set")
            valid_start = time()
            valid_loss, _ = compute_validation_loss_dsl(
                model, criterion, train_data_loader, valid_data_loader,
                predict_type_embed, ubiquitous_types, common_types)
            logger.info(
                f"epoch: {epoch} valid loss: {valid_loss} in {(time() - valid_start) / 60.0:.2f} min."
            )
コード例 #12
0
def filter_variables(df_vars: pd.DataFrame,
                     types=('Any', 'None', 'object', 'type', 'Type[Any]',
                            'Type[cls]', 'Type[type]', 'Type', 'TypeVar',
                            'Optional[Any]')):
    """
    Filters out variables with specified types such as Any or None.

    :param df_vars: variables DataFrame with a ``var_type`` column
    :param types: iterable of type strings to drop. The default is an
        immutable tuple, so it cannot be mutated across calls.
    :return: the rows whose ``var_type`` is not in ``types``
    """
    df_var_len = len(df_vars)
    logger.info(
        f"Variables before dropping on {','.join(types)}: {len(df_vars):,}")
    df_vars = df_vars[~df_vars['var_type'].isin(types)]
    logger.info(
        f"Variables after dropping on {','.join(types)}: {len(df_vars):,}")
    logger.info(f"Filtered out {df_var_len - len(df_vars):,} variables.")

    return df_vars
コード例 #13
0
ファイル: predict.py プロジェクト: saltudelft/type4py
def test(output_path: str,
         data_loading_funcs: dict,
         type_vocab_limit: int = None):
    """
    Tests a trained Type4Py model and saves its predictions on the test set.

    1. Loads the train/validation data and the trained
       ``type4py_<name>_model.pt`` from ``output_path``.
    2. Builds type clusters from the train/validation embeddings, restricted
       to the types listed in ``_most_frequent_all_types.csv``, then saves the
       cluster index and its labels (``type4py_<name>_true.npy``).
    3. Embeds the test samples, runs a KNN search over the clusters and saves
       the predictions to ``type4py_<name>_test_predictions.json``.

    :param output_path: directory holding the model, data and outputs
    :param data_loading_funcs: dict with at least the key ``'name'`` plus the
        loading functions used by the data loaders
    :param type_vocab_limit: if given, only the first ``type_vocab_limit``
        rows of the type-vocabulary CSV are used; if None, all rows are used
    """
    logger.info(f"Testing Type4Py model")
    logger.info(
        f"**********************************************************************"
    )
    # Loading dataset
    logger.info("Loading train and test sets...")

    # Model's hyper parameters
    model_params = load_model_params()
    train_data_loader, valid_data_loader = load_training_data_per_model(
        data_loading_funcs,
        output_path,
        model_params['batches_test'],
        train_mode=False)

    model = torch.load(
        join(output_path, f"type4py_{data_loading_funcs['name']}_model.pt"))
    logger.info(
        f"Loaded the pre-trained Type4Py {data_loading_funcs['name']} model")
    logger.info(
        f"Type4Py's trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}"
    )

    with open(join(output_path, "label_encoder_all.pkl"), 'rb') as f:
        le_all = pickle.load(f)
    type_vocab = pd.read_csv(join(output_path, '_most_frequent_all_types.csv'))
    # Only truncate when a limit is given: DataFrame.head(-1) would silently
    # drop the last row instead of keeping the whole vocabulary.
    if type_vocab_limit is not None:
        type_vocab = type_vocab.head(type_vocab_limit)
    type_vocab = set(le_all.transform(type_vocab['type'].values))

    annoy_index, embed_labels = build_type_clusters(model.model,
                                                    train_data_loader,
                                                    valid_data_loader,
                                                    type_vocab)
    logger.info("Created type clusters")

    annoy_index.save(
        join(output_path,
             f"type4py_{data_loading_funcs['name']}_type_cluster"))
    np.save(
        join(output_path, f"type4py_{data_loading_funcs['name']}_true.npy"),
        embed_labels)
    logger.info("Saved type clusters")

    test_data_loader, t_idx = load_test_data_per_model(
        data_loading_funcs, output_path, model_params['batches_test'])
    logger.info("Mapping test samples to type clusters")
    test_type_embed, embed_test_labels = compute_type_embed_batch(
        model.model, test_data_loader)

    # Perform KNN search and predict
    logger.info("Performing KNN search")

    train_valid_labels = le_all.inverse_transform(embed_labels)
    embed_test_labels = le_all.inverse_transform(embed_test_labels)
    pred_types = predict_type_embed_task(test_type_embed, embed_test_labels,
                                         train_valid_labels, t_idx,
                                         annoy_index, model_params['k'])

    save_json(
        join(output_path,
             f"type4py_{data_loading_funcs['name']}_test_predictions.json"),
        pred_types)
    logger.info("Saved the Type4Py model's predictions on the disk")
コード例 #14
0
def train(output_path: str,
          data_loading_funcs: dict,
          model_params_path=None,
          validation: bool = False):
    """
    Trains a Type4Py model and saves it to disk.

    Loads the hyper-parameters and training data, then derives two type sets.
    The ubiquitous types are the encodings of str/int/list/bool/float. The
    common types are training labels occurring at least 100 times, minus the
    ubiquitous ones. The common types are pickled to
    ``<name>_common_types.pkl``. A ``TripletModel`` is then trained with a
    triplet margin loss and Adam, and saved to ``type4py_<name>_model.pt``.

    :param output_path: directory holding the label encoder/data and
        receiving the outputs
    :param data_loading_funcs: dict with at least the key ``'name'`` plus the
        loading functions used by the data loaders
    :param model_params_path: optional file overriding the default
        hyper-parameters (see ``load_model_params``)
    :param validation: whether to evaluate on the validation set while training
    """
    logger.info(f"Training Type4Py model")
    logger.info(
        f"***********************************************************************"
    )

    # Model's hyper parameters
    model_params = load_model_params(model_params_path)
    train_data_loader, valid_data_loader = load_training_data_per_model(
        data_loading_funcs,
        output_path,
        model_params['batches'],
        no_workers=cpu_count() // 2)

    # Loading label encoder and finding ubiquitous & common types
    with open(join(output_path, "label_encoder_all.pkl"), 'rb') as f:
        le_all = pickle.load(f)
    train_labels = train_data_loader.dataset.labels
    no_train_labels = train_labels.shape[0]
    # Count each label once; .item() converts the NumPy keys to plain Python
    # scalars, matching what iterating the label tensor with .item() yields.
    count_types = Counter(train_labels.data.numpy())
    common_types = {t.item() for t, c in count_types.items() if c >= 100}
    ubiquitous_types = set(
        le_all.transform(['str', 'int', 'list', 'bool', 'float']))
    common_types = common_types - ubiquitous_types

    no_ubiq_labels = sum(c for t, c in count_types.items()
                         if t.item() in ubiquitous_types)
    no_common_labels = sum(c for t, c in count_types.items()
                           if t.item() in common_types)
    logger.info("Percentage of ubiquitous types: %.2f%%" %
                (no_ubiq_labels / no_train_labels * 100.0))
    logger.info("Percentage of common types: %.2f%%" %
                (no_common_labels / no_train_labels * 100.0))

    with open(
            join(output_path,
                 f"{data_loading_funcs['name']}_common_types.pkl"), 'wb') as f:
        pickle.dump(common_types, f)

    # Loading the model
    model = load_model(data_loading_funcs['name'], model_params)
    logger.info(f"Intializing the {model.__class__.__name__} model")
    model = TripletModel(model).to(DEVICE)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = torch.nn.TripletMarginLoss(margin=model_params['margin'])
    optimizer = torch.optim.Adam(model.parameters(), lr=model_params['lr'])

    train_t = time()
    train_loop_dsl(model, criterion, optimizer, train_data_loader,
                   valid_data_loader if validation else None,
                   model_params['lr'], model_params['epochs'],
                   ubiquitous_types, common_types, None)
    logger.info("Training finished in %.2f min" % ((time() - train_t) / 60))

    # Save the bare model (not the DataParallel wrapper), then report it.
    torch.save(
        model.module if isinstance(model, nn.DataParallel) else model,
        join(output_path, f"type4py_{data_loading_funcs['name']}_model.pt"))
    logger.info(
        "Saved the trained Type4Py model for %s prediction on the disk" %
        data_loading_funcs['name'])
コード例 #15
0
def eval_pred_dsl(test_pred: List[dict],
                  common_types,
                  tasks: set,
                  top_n=10,
                  mrr_all=False):
    """
    Computes evaluation metrics: exact/parametric match accuracy and MRR.

    Each prediction dict must provide ``'task'``, ``'original_type'``,
    ``'is_parametric'`` and ``'predictions'``. The latter is a ranked list of
    pairs whose first element is a predicted type string. Only entries whose
    task is in ``tasks`` are evaluated. Each true type is bucketed as
    ubiquitous (a fixed set of builtin/typing names), common (in
    ``common_types``) or rare (everything else).

    * Exact match: the true type is among the first ``top_n`` predictions.
    * Parametric match (common/rare parametric types with no exact match):
      group 1 of the greedy regex ``(.+)\\[(.+)\\]`` on the true type equals
      group 1 on a top-``top_n`` prediction, or (case-insensitively) equals
      the whole prediction when it is not parametric.
    * MRR: reciprocal rank of the first match, 0 when there is none.

    Accuracies of categories with no samples are reported as ``nan``
    instead of raising ZeroDivisionError.

    :param test_pred: list of prediction dicts
    :param common_types: collection of type strings considered common
    :param tasks: tasks to evaluate; ``{'Parameter', 'Return', 'Variable'}``
        is reported as ``Combined``
    :param top_n: number of top predictions considered
    :param mrr_all: whether to also log per-category MRR figures
    :return: the overall exact-match MRR as a percentage
    """
    param_type_match = re.compile(r'(.+)\[(.+)\]')

    def pred_types_fix(y_true: str, y_pred: List[Tuple[str, int]]):
        """Returns (y_true, reciprocal rank) if found in the top-n, else
        (top-1 prediction or None, 0.0)."""
        for i, (p, _) in enumerate(y_pred[:top_n]):
            if p == y_true:
                return p, 1 / (i + 1)

        return (y_pred[0][0] if y_pred else None), 0.0

    def is_param_correct(true_param_type: str, pred_types: List[str]):
        """Returns (1, reciprocal rank) for the first prediction whose outer
        type matches the true type's, otherwise (0, 0.0)."""
        true_match = param_type_match.match(true_param_type)
        if true_match is None:
            # Flagged as parametric but not of the form X[Y]: no comparison possible.
            return 0, 0.0
        true_outer = true_match.group(1)
        for i, p in enumerate(pred_types):
            pred_match = param_type_match.match(p)
            if pred_match:
                matched = true_outer == pred_match.group(1)
            else:
                matched = true_outer.lower() == p.lower()
            if matched:
                return 1, 1 / (i + 1)

        return 0, 0.0

    def pct(correct, total):
        # Percentage, or nan for an empty category (avoids ZeroDivisionError).
        return correct / total * 100.0 if total else float('nan')

    ubiquitous_types = {
        'str', 'int', 'list', 'bool', 'float', 'typing.Text', 'typing.List',
        'typing.List[typing.Any]', 'typing.list'
    }

    all_ubiq_types = 0
    corr_ubiq_types = 0
    all_common_types = 0
    corr_common_types = 0
    all_rare_types = 0
    corr_rare_types = 0

    corr_param_common_types = 0
    corr_param_rare_types = 0

    mrr = []
    mrr_exact_ubiq = []
    mrr_exact_comm = []
    mrr_exact_rare = []

    mrr_param_comm = []
    mrr_param_rare = []

    for p in tqdm(test_pred, total=len(test_pred)):

        if p['task'] not in tasks:
            continue

        true_type = p['original_type']
        top_n_pred, r = pred_types_fix(true_type, p['predictions'])
        mrr.append(r)

        if true_type in ubiquitous_types:
            all_ubiq_types += 1
            mrr_exact_ubiq.append(r)
            if true_type == top_n_pred:
                corr_ubiq_types += 1
        elif true_type in common_types:
            all_common_types += 1
            mrr_exact_comm.append(r)
            if true_type == top_n_pred:
                corr_common_types += 1
            elif p['is_parametric']:
                m, pr = is_param_correct(
                    true_type, [t for t, _ in p['predictions'][:top_n]])
                mrr_param_comm.append(pr)
                corr_param_common_types += m
        else:
            all_rare_types += 1
            mrr_exact_rare.append(r)
            if true_type == top_n_pred:
                corr_rare_types += 1
            elif p['is_parametric']:
                m, pr = is_param_correct(
                    true_type, [t for t, _ in p['predictions'][:top_n]])
                mrr_param_rare.append(pr)
                corr_param_rare_types += m

    all_types = all_ubiq_types + all_common_types + all_rare_types
    corr_exact = corr_ubiq_types + corr_common_types + corr_rare_types

    tasks = 'Combined' if tasks == {'Parameter', 'Return', 'Variable'
                                    } else list(tasks)[0]
    logger.info(
        f"Type4Py - {tasks} - Exact match - all: {pct(corr_exact, all_types):.1f}%"
    )
    logger.info(
        f"Type4Py - {tasks} - Exact match - ubiquitous: {pct(corr_ubiq_types, all_ubiq_types):.1f}%"
    )
    logger.info(
        f"Type4Py - {tasks} - Exact match - common: {pct(corr_common_types, all_common_types):.1f}%"
    )
    logger.info(
        f"Type4Py - {tasks} - Exact match - rare: {pct(corr_rare_types, all_rare_types):.1f}%"
    )

    logger.info(
        f"Type4Py - {tasks} - Parametric match - all: {pct(corr_exact + corr_param_common_types + corr_param_rare_types, all_types):.1f}%"
    )
    logger.info(
        f"Type4Py - {tasks} - Parametric match - common: {pct(corr_param_common_types + corr_common_types, all_common_types):.1f}%"
    )
    logger.info(
        f"Type4Py - {tasks} - Parametric match - rare: {pct(corr_param_rare_types + corr_rare_types, all_rare_types):.1f}%"
    )

    logger.info(f"Type4Py - Mean reciprocal rank {np.mean(mrr)*100:.1f}")

    if mrr_all:
        logger.info(
            f"Type4Py - {tasks} - MRR - Exact match - all: {np.mean(mrr)*100:.1f}"
        )
        logger.info(
            f"Type4Py - {tasks} - MRR - Exact match - ubiquitous: {np.mean(mrr_exact_ubiq)*100:.1f}"
        )
        logger.info(
            f"Type4Py - {tasks} - MRR - Exact match - common: {np.mean(mrr_exact_comm)*100:.1f}"
        )
        logger.info(
            f"Type4Py - {tasks} - MRR - Exact match - rare: {np.mean(mrr_exact_rare)*100:.1f}"
        )
        logger.info(
            f"Type4Py - {tasks} - MRR - Parameteric match - all: {np.mean(mrr_exact_ubiq+mrr_exact_comm+mrr_exact_rare+mrr_param_comm+mrr_param_rare)*100:.1f}"
        )
        logger.info(
            f"Type4Py - {tasks} - MRR - Parameteric match - common: {np.mean(mrr_param_comm+mrr_exact_comm)*100:.1f}"
        )
        logger.info(
            f"Type4Py - {tasks} - MRR - Parameteric match - rare: {np.mean(mrr_param_rare+mrr_exact_rare)*100:.1f}"
        )

    return np.mean(mrr) * 100
コード例 #16
0
def vectorize_args_ret(output_path: str):
    """
    Creates vector representations of functions' arguments, return values and
    variables for the train, validation and test splits.

    Reads ``_ml_{param,ret,var}_{train,valid,test}.csv`` from ``output_path``.
    A W2V token model is trained on the training split unless
    ``w2v_token_model.bin`` already exists. Identifier, code-token,
    visible-type-hint and label vectors are then written under
    ``<output_path>/vectors/<split>``.

    :param output_path: directory holding the CSV files; outputs go to its
        ``vectors`` sub-directory
    """
    # (split name, wording used in the "Loaded the ... data" log message)
    splits = (('train', 'training'), ('valid', 'validation'), ('test', 'test'))

    # For each split: (param_df, return_df, var_df), read in that order.
    split_dfs = {}
    for split, split_desc in splits:
        split_dfs[split] = tuple(
            pd.read_csv(os.path.join(output_path, f"_ml_{kind}_{split}.csv"),
                        na_filter=False) for kind in ('param', 'ret', 'var'))
        logger.info(f"Loaded the {split_desc} data")

    w2v_model_path = os.path.join(output_path, 'w2v_token_model.bin')
    if not os.path.exists(w2v_model_path):
        embedder = W2VEmbedding(*split_dfs['train'], w2v_model_path)
        embedder.train_token_model()
    else:
        logger.warning("Loading an existing pre-trained W2V model!")

    w2v_token_model = Word2Vec.load(w2v_model_path)

    # Create dirs for vectors
    vectors_dir = os.path.join(output_path, "vectors")
    mk_dir_not_exist(vectors_dir)
    split_dirs = {split: os.path.join(vectors_dir, split) for split, _ in splits}
    for split, _ in splits:
        mk_dir_not_exist(split_dirs[split])

    # Row -> sequence transforms for each kind of datapoint.
    def id_trans_func_param(row):
        return IdentifierSequence(w2v_token_model, row.arg_name,
                                  row.other_args, row.func_name, None)

    def token_trans_func_param(row):
        return TokenSequence(w2v_token_model, TOKEN_SEQ_LEN[0],
                             TOKEN_SEQ_LEN[1], row.arg_occur, None, None)

    def id_trans_func_ret(row):
        return IdentifierSequence(w2v_token_model, None, row.arg_names_str,
                                  row.name, None)

    def token_trans_func_ret(row):
        return TokenSequence(w2v_token_model, TOKEN_SEQ_LEN[0],
                             TOKEN_SEQ_LEN[1], None, row.return_expr_str,
                             None)

    def id_trans_func_var(row):
        return IdentifierSequence(w2v_token_model, None, None, None,
                                  row.var_name)

    def token_trans_func_var(row):
        return TokenSequence(w2v_token_model, TOKEN_SEQ_LEN[0],
                             TOKEN_SEQ_LEN[1], None, None, row.var_occur)

    # (index into split_dfs tuples, log tag, file-name tag,
    #  identifier transform, code-token transform)
    datapoint_kinds = (
        (0, 'arg', 'param', id_trans_func_param, token_trans_func_param),
        (1, 'ret', 'ret', id_trans_func_ret, token_trans_func_ret),
        (2, 'var', 'var', id_trans_func_var, token_trans_func_var),
    )
    for df_idx, log_tag, file_tag, id_func, token_func in datapoint_kinds:
        logger.info(f"[{log_tag}][identifiers] Generating vectors")
        for split, _ in splits:
            process_datapoints(split_dfs[split][df_idx], split_dirs[split],
                               'identifiers_', f'{file_tag}_{split}', id_func)

        logger.info(f"[{log_tag}][code tokens] Generating vectors")
        for split, _ in splits:
            process_datapoints(split_dfs[split][df_idx], split_dirs[split],
                               'tokens_', f'{file_tag}_{split}', token_func)

    # Generate data points for visible type hints
    logger.info("[visible type hints] Generating vectors")
    for split, _ in splits:
        gen_aval_types_datapoints(*split_dfs[split], split, split_dirs[split])

    # a flattened vector for labels
    logger.info("[true labels] Generating vectors")
    for split, _ in splits:
        gen_labels_vector(*split_dfs[split], split, split_dirs[split])
コード例 #17
0
 def _report_errors(self, parsed_result):
     """
     Logs a summary of the type checker's errors.

     :param parsed_result: object exposing ``no_type_errs``, ``no_files`` and
         ``err_breakdown``; the breakdown is logged only when it is truthy
     """
     summary = (f"Produced {parsed_result.no_type_errs} type error(s) "
                f"in {parsed_result.no_files} file(s).")
     logger.info(summary)
     if not parsed_result.err_breakdown:
         return
     logger.info(f"Error breaking down: {parsed_result.err_breakdown}.")
コード例 #18
0
def filter_return_dp(df: pd.DataFrame) -> pd.DataFrame:
    """
    Filters return datapoints based on a set of criteria.

    Rows are removed, in this order, when:
    1. `return_type` is missing (NA);
    2. `return_type` is one of the strings 'nan', 'None' or 'Any';
    3. `return_expr` (a string parsed with `ast.literal_eval`) evaluates to an
       empty container, i.e. the function has no return expression.

    :param df: functions DataFrame with `return_type` and `return_expr` columns.
    :return: a filtered DataFrame; surviving rows keep their original index.
    """
    logger.info(f"Functions before dropping on return type {len(df):,}")
    df = df.dropna(subset=['return_type'])
    logger.info(f"Functions after dropping on return type {len(df):,}")

    logger.info(
        f"Functions before dropping nan, None, Any return type {len(df):,}")
    # Real NA values were removed above; here 'nan' and 'None' are matched as
    # literal strings, alongside the uninformative 'Any'.
    df = df[~df['return_type'].isin(['nan', 'None', 'Any'])]
    logger.info(
        f"Functions after dropping nan, None, Any return type {len(df):,}")

    logger.info(
        f"Functions before dropping on empty return expression {len(df):,}")
    df = df[df['return_expr'].apply(lambda x: len(literal_eval(x))) > 0]
    logger.info(
        f"Functions after dropping on empty return expression {len(df):,}")

    return df
コード例 #19
0
def _files_in_set(df: pd.DataFrame, set_name: str) -> pd.Series:
    """Return the `file` column of the rows of `df` whose `set` column equals `set_name`."""
    return df['file'][df['set'] == set_name]


def _select_by_files(df: pd.DataFrame, files) -> pd.DataFrame:
    """Return the rows of `df` whose `file` occurs in `files` (a Series or a one-column DataFrame)."""
    return df[df['file'].isin(files.to_numpy().flatten())]


def _assert_disjoint_files(df_train: pd.DataFrame, df_valid: pd.DataFrame,
                           df_test: pd.DataFrame) -> None:
    """Assert that no file contributes rows to more than one of the train/valid/test splits."""
    train = set(df_train['file'].tolist())
    valid = set(df_valid['file'].tolist())
    test = set(df_test['file'].tolist())
    assert not train & test
    assert not train & valid
    assert not test & valid


def preprocess_ext_fns(output_dir: str, limit: int = None):
    """
    Applies preprocessing steps to the extracted functions and variables.

    Pipeline, in order:
    1. If `all_fns.csv` or `all_vars.csv` is missing from `output_dir`, build
       both from the ".json" files listed under `<output_dir>/processed_projects`.
    2. Split by file into train/valid/test sets. The dataset's own `set` column
       is used when every function and variable row is labelled 'train',
       'valid' or 'test'. Otherwise the unique files are split randomly:
       20% test, then 10% of the remainder for validation. No `random_state`
       is passed to `train_test_split`.
    3. Filter and normalise the types (consistency, aliasing, parametric types),
       label-encode all types and encode the visible/available type hints.
    4. Re-split the processed data by the same files and write
       `label_encoder_all.pkl` plus `_ml_{param,ret,var}_{train,valid,test}.csv`
       into `output_dir`.

    :param output_dir: directory holding the inputs and receiving all outputs.
    :param limit: forwarded to `merge_jsons_to_dict` when the merged CSVs must be built.
    """
    if not (os.path.exists(os.path.join(output_dir, "all_fns.csv"))
            and os.path.exists(os.path.join(output_dir, "all_vars.csv"))):
        logger.info("Merging JSON projects")
        merged_jsons = merge_jsons_to_dict(
            list_files(os.path.join(output_dir, 'processed_projects'),
                       ".json"), limit)
        logger.info("Creating functions' Dataframe")
        create_dataframe_fns(output_dir, merged_jsons)
        logger.info("Creating variables' Dataframe")
        create_dataframe_vars(output_dir, merged_jsons)

    logger.info("Loading vars & fns Dataframe")
    processed_proj_fns = pd.read_csv(os.path.join(output_dir, "all_fns.csv"),
                                     low_memory=False)
    processed_proj_vars = pd.read_csv(os.path.join(output_dir, "all_vars.csv"),
                                      low_memory=False)

    # Split the processed files into train, validation and test sets
    split_names = ['train', 'valid', 'test']
    if processed_proj_fns['set'].isin(split_names).all() and \
       processed_proj_vars['set'].isin(split_names).all():
        logger.info("Found the sets split in the input dataset")
        train_files = _files_in_set(processed_proj_fns, 'train')
        valid_files = _files_in_set(processed_proj_fns, 'valid')
        test_files = _files_in_set(processed_proj_fns, 'test')

        train_files_vars = _files_in_set(processed_proj_vars, 'train')
        valid_files_vars = _files_in_set(processed_proj_vars, 'valid')
        test_files_vars = _files_in_set(processed_proj_vars, 'test')
    else:
        logger.info("Splitting sets randomly")
        uniq_files = np.unique(
            np.concatenate((processed_proj_fns['file'].to_numpy(),
                            processed_proj_vars['file'].to_numpy())))
        # 20% of the files go to test, then 10% of the remaining ones to validation
        train_files, test_files = train_test_split(
            pd.DataFrame(uniq_files, columns=['file']), test_size=0.2)
        train_files, valid_files = train_test_split(
            pd.DataFrame(train_files, columns=['file']), test_size=0.1)
        # Variables follow exactly the same file-level split as functions
        train_files_vars, valid_files_vars, test_files_vars = \
            train_files, valid_files, test_files

    df_train = _select_by_files(processed_proj_fns, train_files)
    logger.info(f"No. of functions in train set: {df_train.shape[0]:,}")
    df_valid = _select_by_files(processed_proj_fns, valid_files)
    logger.info(f"No. of functions in validation set: {df_valid.shape[0]:,}")
    df_test = _select_by_files(processed_proj_fns, test_files)
    logger.info(f"No. of functions in test set: {df_test.shape[0]:,}")

    df_var_train = _select_by_files(processed_proj_vars, train_files_vars)
    logger.info(f"No. of variables in train set: {df_var_train.shape[0]:,}")
    df_var_valid = _select_by_files(processed_proj_vars, valid_files_vars)
    logger.info(
        f"No. of variables in validation set: {df_var_valid.shape[0]:,}")
    df_var_test = _select_by_files(processed_proj_vars, test_files_vars)
    logger.info(f"No. of variables in test set: {df_var_test.shape[0]:,}")

    # Sanity check: the function split must not leak files across sets
    _assert_disjoint_files(df_train, df_valid, df_test)

    # Exclude variables without a type
    processed_proj_vars = filter_var_wo_type(processed_proj_vars)

    logger.info("Making type annotations consistent")
    # Makes type annotations consistent by removing `typing.`, `t.`, and `builtins` from a type.
    processed_proj_fns, processed_proj_vars = make_types_consistent(
        processed_proj_fns, processed_proj_vars)

    # Sanity check: no type may still match `sub_regex` after normalisation
    assert not any(
        regex.match(sub_regex, str(t))
        for t in processed_proj_fns['return_type'])
    assert not any(
        regex.match(sub_regex, t) for t in processed_proj_fns['arg_types'])
    assert not any(
        regex.match(sub_regex, t) for t in processed_proj_vars['var_type'])

    # Filters variables with type Any or None
    processed_proj_vars = filter_variables(processed_proj_vars)

    # Filters trivial functions such as `__str__` and `__len__`
    processed_proj_fns = filter_functions(processed_proj_fns)

    # Extracts type hints for functions' arguments
    processed_proj_fns_params = gen_argument_df(processed_proj_fns)

    # Filters out functions: (1) without a return type (2) with the return type of Any or None (3) without a return expression
    processed_proj_fns = filter_return_dp(processed_proj_fns)
    processed_proj_fns = format_df(processed_proj_fns)

    logger.info("Resolving type aliases")
    # Resolves type aliasing and mappings. e.g. `[]` -> `list`
    processed_proj_fns_params, processed_proj_fns, processed_proj_vars = resolve_type_aliasing(
        processed_proj_fns_params, processed_proj_fns, processed_proj_vars)

    # Sanity check: no unresolved `{}`, `[]`, `[{}]` or `Text` alias may remain
    assert not any(
        regex.match(r'^{}$|\bText\b|^\[{}\]$|^\[\]$', t)
        for t in processed_proj_fns['return_type'])
    assert not any(
        regex.match(r'^{}$|\bText\b|^\[\]$', t)
        for t in processed_proj_fns_params['arg_type'])

    logger.info("Preproceessing parametric types")
    processed_proj_fns_params, processed_proj_fns, processed_proj_vars = preprocess_parametric_types(
        processed_proj_fns_params, processed_proj_fns, processed_proj_vars)
    # Exclude variables without a type
    processed_proj_vars = filter_var_wo_type(processed_proj_vars)

    # encode_all_types also adds the encoded variable-type column to
    # processed_proj_vars in place (that frame is not part of its return value)
    processed_proj_fns, processed_proj_fns_params, le_all = encode_all_types(
        processed_proj_fns, processed_proj_fns_params, processed_proj_vars,
        output_dir)

    # Exclude self from arg names and return expressions
    processed_proj_fns['arg_names_str'] = processed_proj_fns[
        'arg_names'].apply(
            lambda names: " ".join([v for v in names if v != 'self']))
    processed_proj_fns['return_expr_str'] = processed_proj_fns[
        'return_expr'].apply(lambda exprs: " ".join(
            [regex.sub(r"self\.?", '', v) for v in exprs]))

    # Drop all columns useless for the ML model
    processed_proj_fns = processed_proj_fns.drop(columns=[
        'author', 'repo', 'has_type', 'arg_names', 'arg_types', 'arg_descrs',
        'args_occur', 'return_expr'
    ])

    # Visible type hints: prefer a precomputed MT4Py_VTHs.csv when present,
    # otherwise fall back to the most frequent extracted visible types
    vths_path = join(output_dir, 'MT4Py_VTHs.csv')
    if exists(vths_path):
        logger.info("Using visible type hints")
        df_types = pd.read_csv(vths_path).head(AVAILABLE_TYPES_NUMBER)
    else:
        logger.info("Using naive available type hints")
        df_types = gen_most_frequent_avl_types(
            os.path.join(output_dir, "extracted_visible_types"), output_dir,
            AVAILABLE_TYPES_NUMBER)
    processed_proj_fns_params, processed_proj_fns = encode_aval_types(
        processed_proj_fns_params, processed_proj_fns, processed_proj_vars,
        df_types)

    # Split the parameter, return and variable datasets by file into
    # train, validation and test sets (same files as above)
    df_params_train = _select_by_files(processed_proj_fns_params, train_files)
    df_params_valid = _select_by_files(processed_proj_fns_params, valid_files)
    df_params_test = _select_by_files(processed_proj_fns_params, test_files)

    df_ret_train = _select_by_files(processed_proj_fns, train_files)
    df_ret_valid = _select_by_files(processed_proj_fns, valid_files)
    df_ret_test = _select_by_files(processed_proj_fns, test_files)

    df_var_train = _select_by_files(processed_proj_vars, train_files_vars)
    df_var_valid = _select_by_files(processed_proj_vars, valid_files_vars)
    df_var_test = _select_by_files(processed_proj_vars, test_files_vars)

    _assert_disjoint_files(df_params_train, df_params_valid, df_params_test)
    _assert_disjoint_files(df_ret_train, df_ret_valid, df_ret_test)
    _assert_disjoint_files(df_var_train, df_var_valid, df_var_test)

    # Store the dataframes and the label encoders
    logger.info("Saving preprocessed functions on the disk...")
    with open(os.path.join(output_dir, "label_encoder_all.pkl"), 'wb') as file:
        pickle.dump(le_all, file)

    for file_name, split_df in (
            ("_ml_param_train.csv", df_params_train),
            ("_ml_param_valid.csv", df_params_valid),
            ("_ml_param_test.csv", df_params_test),
            ("_ml_ret_train.csv", df_ret_train),
            ("_ml_ret_valid.csv", df_ret_valid),
            ("_ml_ret_test.csv", df_ret_test),
            ("_ml_var_train.csv", df_var_train),
            ("_ml_var_valid.csv", df_var_valid),
            ("_ml_var_test.csv", df_var_test),
    ):
        split_df.to_csv(os.path.join(output_dir, file_name), index=False)
コード例 #20
0
ファイル: data_loaders.py プロジェクト: saltudelft/type4py
def load_training_data_per_model(data_loading_funcs: dict, output_path: str,
                                 no_batches: int, train_mode: bool = True,
                                 no_workers: int = 8) -> Tuple[DataLoader, DataLoader]:
    """
    Loads appropriate training data based on the model's type.

    The model variants differ only in which feature arrays their loaders return:
    'woi' (without identifiers) -> (tokens, visible types),
    'woc' (without code tokens) -> (identifiers, visible types),
    'wov' (without visible type hints) -> (identifiers, tokens),
    any other name (complete model) -> (identifiers, tokens, visible types).
    Label loading, filtering with `select_data` and DataLoader construction are
    shared, so all variants go through a single code path.

    :param data_loading_funcs: dict with a 'name' (dataset variant) and 'train',
        'valid' and 'labels' callables taking `output_path`. 'train'/'valid'
        return the feature arrays in the order TripletDataset takes them;
        'labels' returns (train labels, valid labels, <unused>).
    :param output_path: directory passed to every loading callable.
    :param no_batches: despite its name, used as the DataLoader batch size.
    :param train_mode: forwarded to TripletDataset.
    :param no_workers: number of DataLoader worker processes.
    :return: (train_loader, valid_loader); only the train loader shuffles.
    :raises ValueError: if a loader returns the wrong number of feature arrays
        for the dataset variant.
    """
    dataset_name = data_loading_funcs['name']
    # Ablation variants drop one of the complete model's three feature arrays
    n_features = 2 if dataset_name in ('woi', 'woc', 'wov') else 3

    def load_features(split: str) -> tuple:
        # Load one split's feature arrays and check their count for this variant
        features = tuple(data_loading_funcs[split](output_path))
        if len(features) != n_features:
            raise ValueError(
                f"Expected {n_features} feature arrays for the {split} set of "
                f"the '{dataset_name}' dataset, got {len(features)}")
        return features

    def apply_mask(features: tuple, labels):
        # Keep only the data points selected by `select_data` (threshold MIN_DATA_POINTS)
        mask = select_data(labels, MIN_DATA_POINTS)
        return tuple(x[mask] for x in features), labels[mask]

    load_data_t = time()
    X_train = load_features('train')
    X_valid = load_features('valid')
    Y_all_train, Y_all_valid, _ = data_loading_funcs['labels'](output_path)

    X_train, Y_all_train = apply_mask(X_train, Y_all_train)
    X_valid, Y_all_valid = apply_mask(X_valid, Y_all_valid)

    triplet_data_train = TripletDataset(*X_train, labels=Y_all_train,
                                        dataset_name=dataset_name,
                                        train_mode=train_mode)
    triplet_data_valid = TripletDataset(*X_valid, labels=Y_all_valid,
                                        dataset_name=dataset_name,
                                        train_mode=train_mode)

    logger.info(f"Loaded train and valid sets of the {dataset_name} dataset in {(time()-load_data_t)/60:.2f} min")

    train_loader = DataLoader(triplet_data_train, batch_size=no_batches, shuffle=True,
                              pin_memory=True, num_workers=no_workers)
    valid_loader = DataLoader(triplet_data_valid, batch_size=no_batches, num_workers=no_workers)

    return train_loader, valid_loader