Example #1
    adam = optimizers.Adam(lr=0.005, decay=0.95)
    vae.compile(optimizer=adam, loss=get_vae_loss)
    vae.summary()
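    # Hedged sketch of a VAE loss with the (y_true, y_pred) signature Keras
    # expects from `loss=get_vae_loss` above. The real get_vae_loss lives
    # elsewhere in this repo; here z_mean/z_log_var must be closed over from
    # the encoder, and kl_weight (cf. the kl_logannealing tag below) is an
    # assumption, not the author's value.
    from keras import backend as K

    def make_vae_loss(z_mean, z_log_var, kl_weight=1.0):
        def vae_loss(y_true, y_pred):
            recon = K.mean(K.square(y_true - y_pred), axis=-1)  # MSE reconstruction term
            kl = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
                              axis=-1)  # KL(q(z|x) || N(0, I))
            return recon + kl_weight * kl
        return vae_loss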

    #### ====================data====================
    video_data_train = pickle.load(open(
        './360video/data/shanghai_dataset_xyz_train.p', 'rb'),
                                   encoding='latin1')
    datadb = video_data_train.copy()
    _video_db, _video_db_future, _video_db_future_input = util.get_data(
        datadb, pick_user=False)
    #### only do reconstruction!!
    # _video_db_future = _video_db[:,::-1,:]

    if cfg.input_mean_var:
        encoder_input_data = util.get_gt_target_xyz(_video_db)
    else:
        encoder_input_data = _video_db
    decoder_target_data = util.get_gt_target_xyz(_video_db_future)
    decoder_input_data = util.get_gt_target_xyz(
        _video_db_future_input)[:, 0, :][:, np.newaxis, :]
    # decoder_input_data = util.get_gt_target_xyz(_video_db_future_input)
    # decoder_input_data = np.zeros_like(decoder_input_data) ##zero out the decoder input
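    # Hedged guess at util.get_gt_target_xyz: the input_mean_var /
    # predict_mean_var flags suggest it collapses the per-frame axis of an
    # (N, T, fps, 3) tensor into a per-second mean and variance, i.e.
    # (N, T, 6). Illustrative assumption only, not the repo's code.
    def get_gt_target_xyz_sketch(x):
        mu = x.mean(axis=2)   # (N, T, 3) mean xyz within each second
        var = x.var(axis=2)   # (N, T, 3) variance within each second
        return np.concatenate([mu, var], axis=-1)  # (N, T, 6)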

    ### ====================Training====================
    # tag='vae_seq2seq_dec5'
    # tag='vae_seq2seq_fullTF_dec5'
    # tag='vae_seq2seq_vae20_dec5' #vae_latent_dim=20
    # tag='vae_seq2seq_vae100_dec5'
    # tag='vae_seq2seq_vae20_zerodecoderinput_dec5'
    # tag='vae_seq2seq_vae20_kl_logannealing_dec6'
    # num_testing_sample = int(0.15*total_num_samples)#use last few as test
    num_testing_sample = 1  # already a pure training split; no need to hold samples out for test

    if cfg.shuffle_data:
        #shuffle the whole dataset
        index_shuf = pickle.load(
            open('index_shuf' + '_exp' + str(experiment) + '.p', 'rb'))
        _video_db = shuffle_data(index_shuf, _video_db)
        _video_db_future = shuffle_data(index_shuf, _video_db_future)
        _video_db_future_input = shuffle_data(index_shuf,
                                              _video_db_future_input)
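    # Hedged sketch of shuffle_data: applying one persisted permutation to
    # every array keeps encoder/decoder samples aligned across the split.
    # Assumption, not the repo's actual helper.
    def shuffle_data_sketch(index_shuf, arr):
        return arr[index_shuf]  # fancy-index along the sample axis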

    #prepare training data
    if cfg.input_mean_var:
        encoder_input_data = get_gt_target_xyz(
            _video_db[:-num_testing_sample, :, :].squeeze())[:, :, np.newaxis,
                                                             np.newaxis, :]
        decoder_input_data = get_gt_target_xyz(
            _video_db_future_input[:-num_testing_sample, :]
        )[:, 0, :][:, np.newaxis, np.newaxis, np.newaxis, :]
    else:
        encoder_input_data = _video_db[:-num_testing_sample, :, :]
        decoder_input_data = _video_db_future_input[:-num_testing_sample,
                                                    0, :][:, np.newaxis,
                                                          np.newaxis, :]
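    # Note: both branches above seed the decoder with only the first future
    # timestep ([:, 0, :] or [..., 0, :]), padded with np.newaxis back to the
    # rank the decoder's input layer expects.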

    if cfg.predict_mean_var:
        decoder_target_data = get_gt_target_xyz(
            _video_db_future)[:-num_testing_sample, :, :]
    else:
        decoder_target_data = _video_db_future[:-num_testing_sample, :, :]
Example #3
_video_db_future_oth = _reshape_others_data(_video_db_future_oth)
_video_db_tar = _video_db_tar.reshape(
    (_video_db_tar.shape[0], _video_db_tar.shape[1], fps, 3))
_video_db_future_tar = _video_db_future_tar.reshape(
    (_video_db_tar.shape[0], _video_db_tar.shape[1], fps, 3))
_video_db_future_input_tar = _video_db_future_input_tar.reshape(
    (_video_db_tar.shape[0], _video_db_tar.shape[1], fps, 3))
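# The reshapes split each second's flattened fps*3 vector back into (fps, 3)
# xyz samples; note the future tensors reuse _video_db_tar's leading dims,
# which assumes all three target arrays share (num_samples, num_seconds).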


print('other (past) data shape: ', _video_db_oth.shape)
print('other (future) data shape: ', _video_db_future_oth.shape)
print('target user (past) data shape: ', _video_db_tar.shape)
print('target user (future) data shape: ', _video_db_future_tar.shape)


#### prepare training data
if cfg.input_mean_var:
    ### target user
    encoder_input_data = util.get_gt_target_xyz(_video_db_tar)
    decoder_target_data = util.get_gt_target_xyz(_video_db_future_tar)
    # encoder_input_data = util.get_gt_target_xyz(_video_db_oth_all)
    # decoder_target_data = util.get_gt_target_xyz(_video_db_future_oth_all)
    # ### other users
    others_input_data = util.get_gt_target_xyz_oth(_video_db_future_oth)
    if not cfg.teacher_forcing:
        decoder_input_data = encoder_input_data[:,-1,:][:,np.newaxis,:]
    else:
        decoder_input_data = util.get_gt_target_xyz(_video_db_future_input_tar)
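    # Without teacher forcing the decoder is seeded only with the encoder's
    # last observed mean/var and rolls forward on its own predictions; with
    # teacher forcing it sees the ground-truth future steps at training time.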

else:
    ### target user
    _video_db_tar = _video_db_tar.reshape((_video_db_tar.shape[0],_video_db_tar.shape[1],-1))
    encoder_input_data = _video_db_tar
    # decoder_target_data = _video_db_future_tar #predict raw
Example #4
        # index_shuf = util.get_shuffle_index(total_num_samples)
        index_shuf = pickle.load(open('index_shuf'+'_exp'+str(experiment)+'.p','rb'))
        print('Shuffle data before training and testing.')
        _video_db_tar = util.shuffle_data(index_shuf,_video_db_tar)
        _video_db_future_tar = util.shuffle_data(index_shuf,_video_db_future_tar)
        _video_db_future_input_tar = util.shuffle_data(index_shuf,_video_db_future_input_tar)

        # _video_db_oth = util.shuffle_data(index_shuf,_video_db_oth)
        _video_db_future_oth = util.shuffle_data(index_shuf,_video_db_future_oth)
        # _video_db_future_input_oth = util.shuffle_data(index_shuf,_video_db_future_input_oth)


    #### prepare training data
    if cfg.input_mean_var:
        ### target user
        encoder_input_data = util.get_gt_target_xyz(_video_db_tar[:-num_testing_sample,:,:])
        decoder_target_data = util.get_gt_target_xyz(_video_db_future_tar[:-num_testing_sample,:,:])
        # encoder_input_data = util.get_gt_target_xyz(_video_db_oth_all[:-num_testing_sample,:,:])
        # decoder_target_data = util.get_gt_target_xyz(_video_db_future_oth_all[:-num_testing_sample,:,:])
        # ### other users
        others_fut_input_data = util.get_gt_target_xyz_oth(_video_db_future_oth[:-num_testing_sample])
        if not cfg.teacher_forcing:
            decoder_input_data = encoder_input_data[:,-1,:][:,np.newaxis,:]
        else:
            decoder_input_data = util.get_gt_target_xyz(_video_db_future_input_tar[:-num_testing_sample,:,:])
            # decoder_input_data = util.get_gt_target_xyz(_video_db_future_input_oth_all[:-num_testing_sample,:,:])

    else:
        ### target user
        _video_db_tar = _video_db_tar.reshape((_video_db_tar.shape[0],_video_db_tar.shape[1],-1))
        encoder_input_data = _video_db_tar[:-num_testing_sample,:,:]

def decode_sequence_fov(input_seq, others_fut_input_seq):
    # Seed the decoder with the mean/var of the last observed location,
    # then predict the whole future sequence in one model call.
    last_location = input_seq[0, -1, :][np.newaxis, np.newaxis, :]
    last_mu_var = util.get_gt_target_xyz(last_location)
    decoded_sentence = model.predict(
        [input_seq, others_fut_input_seq, last_mu_var])
    return decoded_sentence

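# Hypothetical single-sample call (shapes assumed, mirroring the inference
# loop in Example #7):
# decoded = decode_sequence_fov(_video_db_tar[0:1], others_fut_input_seq)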
    #     #shuffle the whole dataset
    #     index_shuf = util.get_shuffle_index(_video_db.shape[0])
    #     # index_shuf = pickle.load(open('index_shuf'+'_exp'+str(experiment)+'.p','rb'))
    #     _video_db = util.shuffle_data(index_shuf,_video_db)
    #     _video_db_future = util.shuffle_data(index_shuf,_video_db_future)
    #     _video_db_future_input = util.shuffle_data(index_shuf,_video_db_future_input)

    if cfg.rescale_input:
        _video_db = util._rescale_data(_video_db)
        _video_db_future = util._rescale_data(_video_db_future)
        if not cfg.enc_last_out_as_dec_in:
            _video_db_future_input = util._rescale_data(_video_db_future_input)
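    # Hedged sketch of util._rescale_data; the name only implies an affine
    # rescaling of the xyz coordinates (e.g. [-1, 1] -> [0, 1]). Purely an
    # illustrative assumption, not the repo's actual transform.
    def _rescale_data_sketch(x):
        return (x + 1.0) / 2.0  # map unit-sphere coords from [-1, 1] to [0, 1]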


    if cfg.input_mean_var:
        encoder_input_data = util.get_gt_target_xyz(_video_db)
    else:
        encoder_input_data = _video_db
    if cfg.predict_mean_var:
        decoder_target_data = util.get_gt_target_xyz(_video_db_future)
    else:
        decoder_target_data = _video_db_future

    if not cfg.enc_last_out_as_dec_in:
        if cfg.predict_mean_var:
            decoder_input_data = util.get_gt_target_xyz(_video_db_future_input)[:,0,:][:,np.newaxis,:]
        else:
            decoder_input_data = _video_db_future_input[:,0,:][:,np.newaxis,:]


    # overwrite decoder_input_data using all zeros!!!
Example #7
        decoded_sentence += [output_tokens]
        target_seq = output_tokens
        states_value = [h1, c1, h2, c2]

    return decoded_sentence
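# The loop body above is the tail of a step-by-step seq2seq decode: each
# predicted frame is appended to the output, fed back as the next target_seq,
# and both LSTM layers' (h, c) states are carried into the next step.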


gt_sentence_list = []
decoded_sentence_list = []
for seq_index in range(0, _video_db_tar.shape[0]):
    # for seq_index in range(total_num_samples-num_testing_sample,total_num_samples):
    # for seq_index in range(total_num_samples-num_testing_sample,total_num_samples-num_testing_sample+100):
    # input_seq = _video_db_tar[seq_index: seq_index + 1,:,:]
    # others_fut_input_seq = _video_db_future_oth[seq_index: seq_index + 1,:]
    if cfg.input_mean_var:
        input_seq = util.get_gt_target_xyz(_video_db_tar[seq_index:seq_index +
                                                         1, :, :])
    else:
        input_seq = _video_db_tar[seq_index:seq_index + 1, :]
    others_fut_input_seq = util.get_gt_target_xyz_oth(
        _video_db_future_oth[seq_index:seq_index + 1, :])

    if cfg.teacher_forcing:
        decoded_sentence = decode_sequence_fov_TF(input_seq,
                                                  others_fut_input_seq)
    else:
        decoded_sentence = decode_sequence_fov(input_seq, others_fut_input_seq)
    decoded_sentence_list += [decoded_sentence]
    gt_sentence = _video_db_future_tar[seq_index:seq_index + 1, :, :]
    gt_sentence_list += [gt_sentence]
    # print('-')
    # decoder_target = util.get_gt_target_xyz(gt_sentence)
def _prepare_data(per_video_db_tar,
                  per_video_db_future_tar,
                  per_video_db_future_input_tar,
                  per_video_db_oth,
                  per_video_db_future_oth,
                  per_video_db_future_input_oth,
                  phase='train'):
    def _reshape_others_data(_video_db_oth):
        ## to match Input shape: others_fut_inputs
        _video_db_oth = _video_db_oth.transpose((1, 2, 0, 3))
        _video_db_oth = _video_db_oth.reshape(
            (_video_db_oth.shape[0], _video_db_oth.shape[1],
             _video_db_oth.shape[2], fps, 3))
        return _video_db_oth

    per_video_db_oth = _reshape_others_data(per_video_db_oth)
    per_video_db_future_oth = _reshape_others_data(per_video_db_future_oth)
    per_video_db_tar = per_video_db_tar.reshape(
        (per_video_db_tar.shape[0], per_video_db_tar.shape[1], fps, 3))
    per_video_db_future_tar = per_video_db_future_tar.reshape(
        (per_video_db_tar.shape[0], per_video_db_tar.shape[1], fps, 3))
    per_video_db_future_input_tar = per_video_db_future_input_tar.reshape(
        (per_video_db_tar.shape[0], per_video_db_tar.shape[1], fps, 3))

    # print('other data shape: ',per_video_db_oth.shape)
    # print('other data shape: ',per_video_db_future_oth.shape)
    # print('target user data shape: ',per_video_db_tar.shape)
    # print('target user data shape: ',per_video_db_future_tar.shape)

    if cfg.input_mean_var:
        ### target user
        encoder_input_data = util.get_gt_target_xyz(per_video_db_tar)
        # ### other users
        others_fut_input_data = util.get_gt_target_xyz_oth(
            per_video_db_future_oth)
        if not cfg.teacher_forcing:
            decoder_input_data = encoder_input_data[:, -1, :][:, np.newaxis, :]
        else:
            decoder_input_data = util.get_gt_target_xyz(
                per_video_db_future_input_tar)
    else:
        ### target user
        encoder_input_data = per_video_db_tar.reshape(
            (per_video_db_tar.shape[0], per_video_db_tar.shape[1], -1))
        decoder_input_data = encoder_input_data[:, -1, :][:, np.newaxis, :]
        # decoder_input_data = util.get_gt_target_xyz(encoder_input_data[:,-1,:][:,np.newaxis,:])
        ### other users
        others_fut_input_data = per_video_db_future_oth.transpose(
            (0, 1, 3, 2, 4))
        others_fut_input_data = others_fut_input_data.reshape(
            (others_fut_input_data.shape[0], others_fut_input_data.shape[1],
             others_fut_input_data.shape[2], -1))
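        # Shape flow: _reshape_others_data yields (N, T, num_others, fps, 3);
        # the transpose moves fps ahead of the (presumed) user axis and the
        # reshape flattens (num_others, 3) so each frame carries every other
        # user's xyz at that instant.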
        # others_fut_input_data = util.get_gt_target_xyz_oth(per_video_db_future_oth)

    if cfg.predict_mean_var:
        decoder_target_data = util.get_gt_target_xyz(
            per_video_db_future_tar)  #predict mean/var
    else:
        # decoder_target_data = per_video_db_future_tar[:,:,np.newaxis,:,:]#predict raw
        decoder_target_data = per_video_db_future_tar.reshape(
            (per_video_db_future_tar.shape[0],
             per_video_db_future_tar.shape[1], -1))  #predict raw
    if phase == 'test':
        decoder_target_data = per_video_db_future_tar[:, :, np.newaxis, :, :]

    # print('encoder_input_data shape: ',encoder_input_data.shape)
    # print('decoder_target_data shape: ',decoder_target_data.shape)
    # print('decoder_input_data shape: ',decoder_input_data.shape)
    # print('others_fut_input_data shape: ',others_fut_input_data.shape)
    return encoder_input_data, others_fut_input_data, decoder_input_data, decoder_target_data
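
# Hypothetical call of _prepare_data (names and shapes assumed): target
# arrays of shape (N, T, fps*3), others arrays of shape (num_others, N, T, fps*3):
# enc_in, oth_in, dec_in, dec_tgt = _prepare_data(
#     per_video_db_tar, per_video_db_future_tar, per_video_db_future_input_tar,
#     per_video_db_oth, per_video_db_future_oth, per_video_db_future_input_oth,
#     phase='train')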