Example #1
0
"""Generic utilities for machine learning.
"""

import code
import os

import tensorflow as tf

from ccn.cfg import get_config
# Module-level configuration, fetched once when this module is imported.
CFG = get_config()


def shuffle_together(*tensors, seed=None):
    """Shuffle several tensors along axis 0 with one shared random permutation.

    Every input is reordered by the same permutation, so rows that are aligned
    across inputs (e.g. examples and their labels) stay aligned afterwards.

    Args:
        *tensors: Tensors to shuffle. All are gathered with the same index
            vector built from the first tensor's leading dimension, so they are
            expected to share that leading dimension — confirm at call sites.
        seed: Optional op-level seed forwarded to ``tf.random.shuffle`` for
            reproducible shuffles. Keyword-only; defaults to ``None``.

    Returns:
        A tuple with one shuffled tensor per input, in argument order, or an
        empty tuple when no tensors are given.
    """
    if not tensors:
        return ()
    # tf.range alone is the identity permutation; it must be shuffled for the
    # gather below to actually reorder anything. tf.shape (not the static
    # .shape) keeps this working when the leading dimension is unknown.
    idxs = tf.random.shuffle(tf.range(tf.shape(tensors[0])[0]), seed=seed)
    return tuple(tf.gather(tensor, idxs) for tensor in tensors)


def update_data_dict(data_dict, batch_dict):
    """Fold one batch's values into a running accumulator dict.

    Keys seen for the first time are stored as-is. Keys already present are
    combined with ``+=``, giving a running sum (or concatenation, for lists).

    Args:
        data_dict: Accumulator dict; mutated in place.
        batch_dict: Values produced by the current batch.

    Returns:
        ``data_dict`` itself, after the update.
    """
    for key, batch_value in batch_dict.items():
        if key in data_dict:
            data_dict[key] += batch_value
        else:
            # NOTE(review): the batch's object is stored without copying, so a
            # later in-place ``+=`` on a mutable value (list, ndarray) also
            # modifies that batch's original object.
            data_dict[key] = batch_value
    return data_dict


def normalize_data_dict(data_dict, num_batches):
    for name, value in data_dict.items():
"""Brute force graph matching with TensorFlow code. This works for small graphs.
(For graphs with more than 6 max nodes it becomes practically infeasible, since the
number of node permutations to try grows factorially.)

TODO: Hungarian algorithm
"""

import code
import itertools

import numpy as np
import tensorflow as tf

from ccn.cfg import get_config
CFG = get_config()  # module-level configuration, fetched once at import time


# Precomputed node orderings for brute-force matching, keyed by node count
# (3 through 8). permutations[k] is a (k!, k) integer tensor holding every
# ordering of range(k) in itertools order, e.g. row 1 of permutations[3] is
# [0, 2, 1]. Row counts grow factorially (8! = 40320), which is why brute force
# only suits small graphs.
permutations = {
  k: tf.convert_to_tensor(np.array(list(itertools.permutations(range(k)))))
  for k in range(3, 9)
}


def loss_fn(adj, nf, possible_adjs, possible_nfs):
  acc = {}
  permute_dim = possible_adjs.shape[1]

  # calculate losses along last axis (per node)
  lfn = tf.keras.losses.mean_squared_error if CFG['use_mse_loss'] else \