Example #1
    # correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
    # accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Initialize the variables (i.e. assign their default value)
    init = tf.global_variables_initializer()

    ## data IO
    if cfg.use_xyz:
        all_video_data = pickle.load(open('./data/exp_1_xyz.p', 'rb'))
        # all_video_data = pickle.load(open('./data/exp_2_xyz.p','rb'))
    elif cfg.use_cos_sin:
        all_video_data = pickle.load(open('./data/exp_2_raw_pair.p', 'rb'))
    else:
        all_video_data = pickle.load(open('./data/exp_2_raw.p', 'rb'))

    datadb = clip_xyz(all_video_data)
    data_io = DataLayer(datadb, random=False, is_test=is_test)

    if not is_test:
        # Start training
        with tf.Session() as sess:
            # Run the initializer
            sess.run(init)
            summary_writer = tf.summary.FileWriter('./tfsummary/', sess.graph)
            total_batch = 8 * int(datadb[0]['x'].shape[1] /
                                  cfg.running_length / fps / batch_size)
            for epoch in range(training_epochs):
                avg_cost = 0.
                for step in range(1, total_batch):
                    # print('step',step)
                    batch_x, batch_y = data_io._get_next_minibatch(
Example #2
    else:
        decoder_states, state_h, state_c = lstm(this_inputs,
                                         initial_state=states)
    outputs = output_dense(decoder_states)
    all_outputs.append(expand_dim_layer(outputs))
    # this_inputs = outputs
    states = [state_h, state_c]

all_outputs = Lambda(lambda x: K.concatenate(x, axis=1))(all_outputs)

model = Model(inputs, all_outputs)
model.compile(optimizer='Adam', loss='mean_squared_error',metrics=['accuracy'])
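
Because each loop iteration appends one expanded (batch, 1, dim) slice and the final Lambda concatenates them along axis 1, the model emits one timestep per decoding step. A quick sanity check, using only the definitions above, is:

# Optional sanity check (assumes the definitions above): the concatenated
# decoder output should have one timestep per loop iteration.
model.summary()
print(model.output_shape)   # expected: (None, num_decoder_steps, output_dim)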

#### ========================================data============================================================
video_data_train = pickle.load(open('./360video/data/shanghai_dataset_xyz_train.p','rb'),encoding='latin1')    
video_data_train = clip_xyz(video_data_train)
datadb = video_data_train.copy()
# assert cfg.data_chunk_stride == 1
_video_db,_video_db_future,_video_db_future_input = util.get_data(datadb,pick_user=False,num_user=34)

if cfg.input_mean_var:
    input_data = util.get_gt_target_xyz(_video_db_future_input)
else:
    input_data = _video_db_future_input
target_data = util.get_gt_target_xyz(_video_db_future)

### ====================Training====================
tag = 'single_LSTM_keras_sep24'
model_checkpoint = ModelCheckpoint(tag+'{epoch:02d}-{val_loss:.4f}.h5', monitor='val_loss', save_best_only=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                 patience=3, min_lr=1e-6)
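
The fit call itself is not shown in this snippet; a plausible sketch, mirroring the call in Example #5 and assuming batch_size and num_epochs are defined in the surrounding configuration (they are not shown here), would be:

# Hypothetical fit call, not part of the original snippet; batch_size and
# num_epochs are assumed to come from the surrounding configuration.
model.fit(input_data, target_data,
          batch_size=batch_size,
          epochs=num_epochs,
          validation_split=0.1,
          shuffle=cfg.shuffle_data,
          callbacks=[model_checkpoint, reduce_lr])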
if cfg.use_xyz:
    all_video_data = pickle.load(open('./data/new_exp_'+str(experiment)+'_xyz.p','rb'))
    # all_video_data = pickle.load(open('./data/exp_'+str(experiment)+'_xyz.p','rb'))
    data_dim = 3
elif cfg.use_yaw_pitch_roll:
    all_video_data = pickle.load(open('./data/exp_2_raw.p','rb'))
    data_dim = 2 #only use yaw and pitch
elif cfg.use_cos_sin:
    all_video_data = pickle.load(open('./data/exp_2_raw_pair.p','rb'))
    data_dim = 2
elif cfg.use_phi_theta:
    all_video_data = pickle.load(open('./data/exp_2_phi_theta.p','rb'))
    data_dim = 2
if cfg.process_in_seconds:
    data_dim = data_dim*fps
all_video_data = clip_xyz(all_video_data)
datadb = all_video_data.copy()


# ## start to train
# Graph def


# tf Graph input
# if cfg.own_history_only:
#     x = tf.placeholder("float", [None, truncated_backprop_length, data_dim])
# else:
#     x = tf.placeholder("float", [None, truncated_backprop_length, 48*data_dim])
x = tf.placeholder("float", [None, truncated_backprop_length, data_dim])
if cfg.has_reconstruct_loss:
Example #4
def get_segment_index(datadb):
    """Segment times are used to fetch the visual/saliency information; they must be matched in time."""
    if not cfg.use_saliency:
        # Without saliency the segment indices are never needed; fail loudly
        # instead of hitting an UnboundLocalError at the return below.
        raise ValueError('get_segment_index() requires cfg.use_saliency')
    segment_index_tar = util.get_time_for_visual(datadb)
    segment_index_tar_future = OrderedDict()
    for key in segment_index_tar.keys():
        segment_index_tar_future[key] = np.array(segment_index_tar[key]) + max_encoder_seq_length
    return segment_index_tar, segment_index_tar_future


#### ========================================data============================================================
if use_generator:
    video_data_train = pickle.load(open('./360video/data/shanghai_dataset_xyz_train.p','rb'))    
    datadb_train = clip_xyz(video_data_train)
    video_data_test = pickle.load(open('./360video/data/shanghai_dataset_xyz_test.p','rb'))    
    datadb_test = clip_xyz(video_data_test)

    #saliency
    if cfg.use_saliency:
        segment_index_tar,segment_index_tar_future = get_segment_index(datadb_train)
        mygenerator = generator_train2(datadb_train,segment_index_tar,segment_index_tar_future,phase='train')

        segment_index_tar_test,segment_index_tar_future_test = get_segment_index(datadb_test)
        mygenerator_val = generator_train2(datadb_test,segment_index_tar_test,segment_index_tar_future_test,phase='val')
    else:       
        # no saliency
        mygenerator = generator_train2(datadb_train,phase='train')
        mygenerator_val = generator_train2(datadb_test,phase='val')
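
The snippet stops before the generators are consumed; a hedged sketch of the training call, assuming a compiled Keras model and per-epoch step counts defined elsewhere in the script, could look like:

# Hypothetical training call (not in the original snippet). The model object,
# steps_per_epoch, num_epochs, and validation_steps are assumed to be defined
# elsewhere in the script.
model.fit_generator(mygenerator,
                    steps_per_epoch=steps_per_epoch,
                    epochs=num_epochs,
                    validation_data=mygenerator_val,
                    validation_steps=validation_steps)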
Example #5
                epochs=starting_epoch + 1,
                validation_split=0.1,
                shuffle=cfg.shuffle_data,
                initial_epoch=starting_epoch,
                callbacks=[model_checkpoint, reduce_lr, stopping])

    ### ====================Testing====================
    video_data_test = pickle.load(open(
        './360video/data/shanghai_dataset_xyz_test.p', 'rb'),
                                  encoding='latin1')
    thu_tag = ''
    ### data format 5
    # video_data_test = pickle.load(open('./360video/temp/tsinghua_after_bysec_interpolation/tsinghua_test_video_data_over_video.p','rb'),encoding='latin1')
    # thu_tag='_thu_'

    video_data_test = clip_xyz(video_data_test)
    datadb = video_data_test.copy()
    _video_db, _video_db_future, _video_db_future_input = util.get_data(
        datadb, pick_user=False)
    if cfg.input_mean_var:
        _video_db = util.get_gt_target_xyz(_video_db)

    def decode_sequence_fov(input_seq):
        last_location = input_seq[0, -1, :][np.newaxis, np.newaxis, :]
        if cfg.input_mean_var:
            last_mu_var = last_location
        else:
            last_mu_var = util.get_gt_target_xyz(last_location)
        # last_mu_var = np.zeros_like(last_mu_var) ##zero out the decoder input
        decoded_sentence = vae.predict([input_seq, last_mu_var])
        return decoded_sentence
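
decode_sequence_fov operates on one clip at a time; a hypothetical driver loop over the prepared test clips (the loop bounds and the stacked array shape are assumptions, not shown in the original) could be:

    # Hypothetical usage sketch: decode every test clip and stack the predictions.
    decoded_all = []
    for idx in range(_video_db.shape[0]):
        input_seq = _video_db[idx:idx + 1, :, :]   # one clip, shape (1, seq_len, dim)
        decoded_all.append(decode_sequence_fov(input_seq))
    decoded_all = np.concatenate(decoded_all, axis=0)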