Example #1
def assert_model_trains(input_tmaps: List[TensorMap],
                        output_tmaps: List[TensorMap],
                        m: Optional[tf.keras.Model] = None,
                        skip_shape_check: bool = False):
    if m is None:
        m = make_multimodal_multitask_model(
            input_tmaps,
            output_tmaps,
            **DEFAULT_PARAMS,
        )
    if not skip_shape_check:
        for tmap, tensor in zip(input_tmaps, m.inputs):
            assert tensor.shape[1:] == tmap.shape
        for tmap, tensor in zip(parent_sort(output_tmaps), m.outputs):
            assert tensor.shape[1:] == tmap.shape
    data = make_training_data(input_tmaps, output_tmaps)
    history = m.fit(data,
                    steps_per_epoch=2,
                    epochs=2,
                    validation_data=data,
                    validation_steps=2)
    for tmap in output_tmaps:
        for metric in tmap.metrics:
            metric_name = metric if isinstance(metric, str) else metric.__name__
            name = f'{tmap.output_name()}_{metric_name}' if len(output_tmaps) > 1 else metric_name
            assert name in history.history
Example #2
def _process_args(args):
    now_string = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
    args_file = os.path.join(args.output_folder, args.id, 'arguments_' + now_string + '.txt')
    command_line = f"\n./scripts/tf.sh {' '.join(sys.argv)}\n"
    if not os.path.exists(os.path.dirname(args_file)):
        os.makedirs(os.path.dirname(args_file))
    with open(args_file, 'w') as f:
        f.write(command_line)
        for k, v in sorted(args.__dict__.items(), key=operator.itemgetter(0)):
            f.write(k + ' = ' + str(v) + '\n')
    load_config(args.logging_level, os.path.join(args.output_folder, args.id), 'log_' + now_string, args.min_sample_id)
    args.u_connect = _process_u_connect_args(args.u_connect, args.tensormap_prefix)
    args.pairs = _process_pair_args(args.pairs, args.tensormap_prefix)

    args.tensor_maps_in = []
    args.tensor_maps_out = []
    if args.text_file is not None:
        del args.input_tensors[:2]
        del args.output_tensors[0]
        input_map, burn_in, output_map = generate_random_text_tensor_maps(args.text_file, args.text_window, args.text_one_hot)
        if args.text_one_hot:
            args.tensor_maps_in.append(input_map)
        else:
            args.tensor_maps_in.extend([input_map, burn_in])
        args.tensor_maps_out.append(output_map)

    args.tensor_maps_in.extend([tensormap_lookup(it, args.tensormap_prefix) for it in args.input_tensors])

    if args.continuous_file is not None:
        # Continuous TensorMap generated from file is given the name specified by the first output_tensors argument
        args.tensor_maps_out.append(
            generate_continuous_tensor_map_from_file(
                args.continuous_file,
                args.continuous_file_column,
                args.output_tensors.pop(0),
                args.continuous_file_normalize,
                args.continuous_file_discretization_bounds,
            ),
        )

    args.tensor_maps_out.extend([tensormap_lookup(ot, args.tensormap_prefix) for ot in args.output_tensors])
    args.tensor_maps_out = parent_sort(args.tensor_maps_out)
    args.tensor_maps_protected = [tensormap_lookup(it, args.tensormap_prefix) for it in args.protected_tensors]
    args.sample_weight = tensormap_lookup(args.sample_weight, args.tensormap_prefix) if args.sample_weight else None
    if args.sample_weight:
        assert args.sample_weight.shape == (1,)

    args.bottleneck_type = BOTTLENECK_STR_TO_ENUM[args.bottleneck_type]
    if args.bottleneck_type == BottleneckType.NoBottleNeck:
        check_no_bottleneck(args.u_connect, args.tensor_maps_out)

    if args.learning_rate_schedule is not None and args.patience < args.epochs:
        raise ValueError('learning_rate_schedule is not compatible with ReduceLROnPlateau. Set patience > epochs.')

    np.random.seed(args.random_seed)

    logging.info(f"Command Line was: {command_line}")
    logging.info(f"Arguments are {args}\n")

    if args.eager:
        import tensorflow as tf
        tf.config.experimental_run_functions_eagerly(True)
Example #3
def test_parent_sort_cycle(tmaps):
    with pytest.raises(ValueError):
        parent_sort(tmaps)
Example #4
def test_parent_sort_idempotent(tmaps):
    assert parent_sort(tmaps) == parent_sort(
        parent_sort(tmaps)) == parent_sort(parent_sort(parent_sort(tmaps)))
Example #5
def test_parent_sort(tmaps):
    assert parent_sort(tmaps) == PARENT_TMAPS
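
Taken together, the three tests above pin down a small contract for parent_sort: on PARENT_TMAPS it returns the expected order, applying it repeatedly changes nothing, and a cycle in the parent relationships raises ValueError. The sketch below is an illustrative, self-contained re-implementation of that contract only, not the library's actual parent_sort; the Tmap class and the parent_sort_sketch name are placeholders assumed for the example.

from typing import List, Sequence


class Tmap:
    """Hypothetical stand-in for TensorMap: just a name and a list of parent maps."""

    def __init__(self, name: str, parents: Sequence['Tmap'] = ()):
        self.name = name
        self.parents = list(parents)

    def __repr__(self):
        return self.name


def parent_sort_sketch(tmaps: List[Tmap]) -> List[Tmap]:
    """Stable topological sort: every map is placed after its parents.

    Raises ValueError when the parent relationships contain a cycle,
    which is the behaviour test_parent_sort_cycle expects.
    """
    remaining = list(tmaps)
    ordered: List[Tmap] = []
    placed = set()  # ids of maps already emitted
    while remaining:
        for i, tmap in enumerate(remaining):
            # A map is ready once every parent that is part of the input list
            # has already been placed; parents outside the list are ignored.
            if all(p not in tmaps or id(p) in placed for p in tmap.parents):
                ordered.append(remaining.pop(i))
                placed.add(id(tmap))
                break
        else:
            # No map could be placed in a full pass: the rest form a cycle.
            raise ValueError(f'Cycle detected among {remaining}')
    return ordered


# Usage matching the properties exercised by the tests:
a = Tmap('a')
b = Tmap('b', parents=[a])
c = Tmap('c', parents=[b])
assert parent_sort_sketch([c, a, b]) == [a, b, c]
assert parent_sort_sketch([a, b, c]) == parent_sort_sketch(parent_sort_sketch([a, b, c]))  # idempotent
a.parents = [c]  # introduce a cycle: parent_sort_sketch([a, b, c]) now raises ValueError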