Example #1
0

parser = parsers.test_parser
args = parser.parse_args()

path = args.input
batch_size = args.batch_size
buffer_size = batch_size * 5
regex = r"joint_t0__joint_t1_[0-9]+_[0-9]+_[0-9]+_[0-9]+_[0-9]+_[0-9]+_nd([0-9]+)"
dirname = os.path.basename(os.path.normpath(args.network_path))
N_DISCRETE = int(re.match(regex, dirname).group(1))
regex = r"sf[0-9]+.?[0-9]*_re[0-9]+_ae[0-9]+_n([0-9]+)_chunk[0-9]+.tfr"
n_records = int(re.match(regex, os.listdir(path + '/positions')[0]).group(1))

dataset_t0 = database.get_dataset(path, positions=True, actions=True, vision=True)
dataset_t0 = dataset_t0.map(database.discretize_dataset(N_DISCRETE))
dataset_t1 = dataset_t0.skip(1)
dataset = tf.data.Dataset.zip((dataset_t0, dataset_t1))
dataset = dataset.batch(batch_size)

iterator = dataset.make_initializable_iterator()
batch_t0, batch_t1 = iterator.get_next()

discrete_positions = [tf.squeeze(x, axis=1) for x in tf.split(batch_t0["positions"], 4, axis=1)]
discrete_actions = [tf.squeeze(x, axis=1) for x in tf.split(batch_t0["actions"], 4, axis=1)]
discrete_positions_target = [tf.squeeze(x, axis=1) for x in tf.split(batch_t1["positions"], 4, axis=1)]

joint_predictors = [predictor_maker(2 * N_DISCRETE, N_DISCRETE) for a in discrete_actions]
inps = [tf.concat([p, a], axis=1) for p, a in zip(discrete_positions, discrete_actions)]
outs = [joint_predictor(inp) for inp, joint_predictor in zip(inps, joint_predictors)]
losses = [mse(out, target, axis=-1) for out, target in zip(outs, discrete_positions_target)]
Example #2
0
parser = parsers.train_parser
parser.add_argument('-d',
                    '--n-discrete',
                    type=int,
                    default=60,
                    help="Discretization precision.")

args = parser.parse_args()

path = args.input
batch_size = args.batch_size
buffer_size = batch_size * 5
N_DISCRETE = args.n_discrete

dataset = database.get_dataset(path, positions=True, vision=True)
dataset = dataset.map(database.discretize_dataset(N_DISCRETE))
dataset = dataset.map(database.vision_to_float32)
dataset = dataset.prefetch(5 * batch_size)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat()

iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()

size = np.prod(batch["vision"].get_shape().as_list()[1:])
vision = tf.reshape(batch["vision"], [-1, size])
discrete_positions = [
    tf.squeeze(x, axis=1) for x in tf.split(batch["positions"], 4, axis=1)
]