# Example #1
# Read config entries. `input_frame_count` must be assigned before the input
# shape below is built from it.
input_frame_count = config.input_frame_count
decode_predictions = config.decode_predictions
skip_pred_steps = config.skip_pred_steps
init_state_network = config.init_state_network
in_out_states = config.in_out_states
pred_gradient_loss = config.pred_gradient_loss
ls_prediction_loss = config.ls_prediction_loss
ls_supervision = config.ls_supervision
sqrd_diff_loss = config.sqrd_diff_loss
ls_split = config.ls_split

# Prefer the "data_args.txt" stored in the model's base directory; fall back to
# the "args.txt" of the configured data set directory when it does not exist.
model_base_dir = find_model_base_dir(args.load_path)
data_args_path = os.path.join(model_base_dir, "data_args.txt")
if not os.path.exists(data_args_path):
    data_args_path = os.path.join(config.data_path, "args.txt")
dataset_meta_info = read_args_file(data_args_path)

# Two parameters are always present -> scene num and frame num; the remaining
# count is clamped to at least 1.
sup_param_count = max(1, int(dataset_meta_info['num_param']) - 2)
res_x = int(dataset_meta_info["resolution_x"])
res_y = int(dataset_meta_info["resolution_y"])
res_z = int(dataset_meta_info["resolution_z"])

# Channel count: 3 when data_type mentions "density", otherwise 2; one extra
# channel in 3D.
in_out_dim = 3 if "density" in config.data_type else 2
in_out_dim = in_out_dim + 1 if config.is_3d else in_out_dim

# Input shape: (frames[, res_z], res_y, res_x, channels).
input_shape = (input_frame_count,)
input_shape += (res_z,) if config.is_3d else ()
input_shape += (res_y, res_x, in_out_dim)
# read config entries
input_frame_count = config.input_frame_count
prediction_window = config.w_num
decode_predictions = config.decode_predictions
skip_pred_steps = config.skip_pred_steps
init_state_network = config.init_state_network
in_out_states = config.in_out_states
pred_gradient_loss = config.pred_gradient_loss
ls_prediction_loss = config.ls_prediction_loss
ls_supervision = config.ls_supervision
sqrd_diff_loss = config.sqrd_diff_loss
ls_split = config.ls_split
train_prediction_only = config.train_prediction_only

# read dataset meta information from the data set directory
dataset_meta_info = read_args_file(os.path.join(config.data_path, 'args.txt'))
# two parameters are always present -> scene num and frame num; clamp to >= 1
sup_param_count = max(1, int(dataset_meta_info['num_param']) - 2)
res_x = int(dataset_meta_info["resolution_x"])
res_y = int(dataset_meta_info["resolution_y"])
res_z = int(dataset_meta_info["resolution_z"])

# Channel count: 3 when data_type mentions "levelset", otherwise 2; one extra
# channel in 3D.
# NOTE(review): initialize_networks keys the extra channel on "density"
# instead of "levelset" -- confirm which data types should add a channel.
in_out_dim = 3 if "levelset" in config.data_type else 2
in_out_dim = in_out_dim + 1 if config.is_3d else in_out_dim
# input shape: (frames[, res_z], res_y, res_x, channels)
input_shape = (input_frame_count,)
input_shape += (res_z,) if config.is_3d else ()
input_shape += (res_y, res_x, in_out_dim)

# Both model variants take identical constructor arguments; `classic_ae`
# only selects which class is instantiated.
rec_pred_class = RecursivePredictionCleanSplit if classic_ae else RecursivePrediction
rec_pred = rec_pred_class(
    config=config,
    input_shape=input_shape,
    decode_predictions=decode_predictions,
    skip_pred_steps=skip_pred_steps,
    init_state_network=init_state_network,
    in_out_states=in_out_states,
    pred_gradient_loss=pred_gradient_loss,
    ls_prediction_loss=ls_prediction_loss,
    ls_supervision=ls_supervision,
    sqrd_diff_loss=sqrd_diff_loss,
    ls_split=ls_split,
    supervised_parameters=sup_param_count,
    train_prediction_only=train_prediction_only,
)


def initialize_networks(args, config, norm_factors):
    """Build the recursive prediction network, load its weights and derive helpers.

    Args:
        args: command-line arguments; ``args.classic_ae`` selects the model
            class and ``args.load_path`` is passed to ``load_model``.
        config: configuration object supplying the data set location
            (``data_path``), ``data_type``, ``is_3d``, ``input_frame_count``
            and the model hyper-parameters forwarded to the constructor.
        norm_factors: stored unchanged on the returned object.

    Returns:
        A namespace object with the attributes ``norm_factors``,
        ``dataset_meta_info``, ``sup_param_count``, ``res_x``, ``res_y``,
        ``res_z``, ``input_shape``, ``classic_ae``, ``is_3d``, ``rec_pred``,
        ``pred`` and ``prediction_history``.
    """
    from types import SimpleNamespace

    net = SimpleNamespace()
    net.norm_factors = norm_factors
    # read dataset meta information from data set directory
    net.dataset_meta_info = read_args_file(
        os.path.join(config.data_path, 'args.txt'))
    # two parameters are always present -> scene num and frame num;
    # the supervised parameter count is clamped to at least 1
    net.sup_param_count = max(1, int(net.dataset_meta_info['num_param']) - 2)
    net.res_x = int(net.dataset_meta_info["resolution_x"])
    net.res_y = int(net.dataset_meta_info["resolution_y"])
    net.res_z = int(net.dataset_meta_info["resolution_z"])
    # prepare input shape: (frames[, res_z], res_y, res_x, channels);
    # 3 channels when data_type mentions "density", otherwise 2, +1 in 3D
    in_out_dim = 3 if "density" in config.data_type else 2
    if config.is_3d:
        in_out_dim += 1
    net.input_shape = (config.input_frame_count,)
    if config.is_3d:
        net.input_shape += (net.res_z,)
    net.input_shape += (net.res_y, net.res_x, in_out_dim)
    # create models
    net.classic_ae = args.classic_ae
    # NOTE(review): is_3d is derived from res_z here, while input_shape above
    # follows config.is_3d -- confirm the two always agree.
    net.is_3d = net.res_z > 1
    # both variants take identical constructor arguments; only the class differs
    rec_pred_class = (RecursivePredictionCleanSplit
                      if net.classic_ae else RecursivePrediction)
    net.rec_pred = rec_pred_class(
        config=config,
        input_shape=net.input_shape,
        decode_predictions=config.decode_predictions,
        skip_pred_steps=config.skip_pred_steps,
        init_state_network=config.init_state_network,
        in_out_states=config.in_out_states,
        pred_gradient_loss=config.pred_gradient_loss,
        ls_prediction_loss=config.ls_prediction_loss,
        ls_supervision=config.ls_supervision,
        sqrd_diff_loss=config.sqrd_diff_loss,
        ls_split=config.ls_split,
        supervised_parameters=net.sup_param_count)
    # load weights from file
    net.rec_pred.load_model(args.load_path)
    # create separate prediction model and copy over weights
    net.pred = Prediction(config=net.rec_pred.config,
                          input_shape=(net.rec_pred.w_num, net.rec_pred.z_num))
    net.pred._build_model()
    net.pred.model.set_weights(net.rec_pred.pred.model.get_weights())
    # create prediction history
    net.prediction_history = PredictionHistory(
        in_ts=net.rec_pred.w_num, data_shape=(net.rec_pred.z_num,))
    return net