示例#1
0
    def load_data_to_buffer(self, data_file_range):
        """Fill ``_x_buffer`` / ``_y_buffer`` and record ``_buffer_size``.

        When ``self._load_pre_processd_data`` is truthy, both buffers are
        taken directly from ``self._get_pre_processd_data.load_data()`` and
        ``data_file_range`` is not used.  Otherwise the samples are read with
        ``get_data_from_files`` for the model type stored in
        ``self._params['model_type']``.  For the 'cnn' and 'siam' model types
        each string-encoded image is decoded with ``string2image``,
        transposed with ``axes=[2, 1, 0]`` and flattened; a 'siam' sample
        becomes a tuple of two such flattened images.  For any other model
        type the loaded x data is stored unchanged.

        Args:
            data_file_range: forwarded unchanged to ``get_data_from_files``.

        Raises:
            AssertionError: if the x and y buffers differ in length.
        """
        if self._load_pre_processd_data:
            # Pre-processed data is delivered ready to use by its loader.
            loaded = self._get_pre_processd_data.load_data()
            self._x_buffer, self._y_buffer = loaded
        else:
            model_type = self._params['model_type']
            raw_x, self._y_buffer = get_data_from_files(
                data_file_range=data_file_range, model_type=model_type)

            def flat_image(encoded):
                # Image inputs arrive as strings and must be converted
                # to numeric arrays before use.
                return np.transpose(string2image(encoded),
                                    axes=[2, 1, 0]).flatten()

            if model_type == 'siam':
                # Each siam sample carries two images.
                self._x_buffer = [
                    (flat_image(pair[0][0]), flat_image(pair[1][0]))
                    for pair in raw_x
                ]
            elif model_type == 'cnn':
                self._x_buffer = [flat_image(sample[0]) for sample in raw_x]
            else:
                self._x_buffer = raw_x

        sample_count = len(self._x_buffer)
        assert sample_count == len(self._y_buffer)
        self._buffer_size = sample_count
示例#2
0
def get_data(operation):
    """Load the samples of one data split for the 'fwd' model type.

    Args:
        operation: 'test' or 'train'; selects which list of file indices is
            read from ``network_params_fwd``.

    Returns:
        The ``(data_x, data_y)`` pair produced by ``get_data_from_files``.

    Raises:
        ValueError: if ``operation`` is neither 'test' nor 'train'.
    """
    # Validate up front: an unknown operation used to leave the index
    # variable unassigned and only failed later with an UnboundLocalError.
    if operation == 'test':
        indices_key = 'test_file_indices'
    elif operation == 'train':
        indices_key = 'train_file_indices'
    else:
        raise ValueError(
            "operation must be 'test' or 'train', got {!r}".format(operation))

    data_x, data_y = get_data_from_files(
        data_file_range=network_params_fwd[indices_key], model_type='fwd')

    return data_x, data_y
def get_data(operation, string_img_convert=True):
    """Load the samples of one data split for the 'cnn' model type.

    Args:
        operation: 'test' or 'train'; selects which list of file indices is
            read from ``network_params_cmbnd``.
        string_img_convert: when true, the first element of every x sample
            is decoded with ``string2image`` and flattened; when false the
            x data is returned as loaded.

    Returns:
        ``(data_x, data_y)``, both converted with ``np.asarray``.

    Raises:
        ValueError: if ``operation`` is neither 'test' nor 'train'.
    """
    # Validate up front: an unknown operation used to leave the index
    # variable unassigned and only failed later with an UnboundLocalError.
    if operation == 'test':
        indices_key = 'test_file_indices'
    elif operation == 'train':
        indices_key = 'train_file_indices'
    else:
        raise ValueError(
            "operation must be 'test' or 'train', got {!r}".format(operation))

    tmp_x, data_y = get_data_from_files(
        data_file_range=network_params_cmbnd[indices_key], model_type='cnn')

    if string_img_convert:
        data_x = [string2image(x_image[0]).flatten() for x_image in tmp_x]
    else:
        data_x = tmp_x

    return np.asarray(data_x), np.asarray(data_y)
示例#4
0
def get_data(operation, string_img_convert=True):
    """Load 'siam' samples for one data split -- currently disabled.

    The function raises unconditionally on entry, so every statement after
    the ``raise`` is unreachable.  The remaining code is kept as the
    intended implementation: it reads the file indices for ``operation``
    ('test' or 'train') from ``network_params_siam``, loads the samples with
    ``get_data_from_files`` and, when ``string_img_convert`` is true, turns
    each sample into a tuple of two decoded, transposed and flattened
    images.

    Raises:
        Exception: always, with the message "Fix this damn thing".
    """
    raise Exception("Fix this damn thing")

    # ---- unreachable until the guard above is removed ----
    if operation == 'test':
        data_file_indices = network_params_siam['test_file_indices']
    elif operation == 'train':
        data_file_indices = network_params_siam['train_file_indices']

    raw_x, data_y = get_data_from_files(
        data_file_range=data_file_indices, model_type='siam')

    def flat_image(encoded):
        # Decode the string, transpose with axes=[2, 1, 0], flatten.
        return np.transpose(string2image(encoded), axes=[2, 1, 0]).flatten()

    if string_img_convert:
        data_x = [
            (flat_image(pair[0][0]), flat_image(pair[1][0]))
            for pair in raw_x
        ]
    else:
        data_x = raw_x

    return np.asarray(data_x), np.asarray(data_y)
示例#5
0
def get_data(data_file_indices, model_type):
    """Load x/y samples from files and decode image inputs where needed.

    The samples are read with ``get_data_from_files``.  For the 'cnn' and
    'siam' model types each string-encoded image is decoded with
    ``string2image``, transposed with ``axes=[2, 1, 0]`` and flattened; a
    'siam' sample becomes a tuple of two such flattened images.  For any
    other model type the loaded x data is returned unchanged.

    Args:
        data_file_indices: forwarded to ``get_data_from_files`` as
            ``data_file_range``.
        model_type: forwarded to ``get_data_from_files``; also selects the
            image conversion described above.

    Returns:
        ``(x_buffer, y_buffer)``.

    Raises:
        AssertionError: if the x and y buffers differ in length.
    """
    raw_x, y_buffer = get_data_from_files(data_file_range=data_file_indices,
                                          model_type=model_type)

    def flat_image(encoded):
        # Image inputs arrive as strings and must be converted to numeric
        # arrays before use.
        return np.transpose(string2image(encoded), axes=[2, 1, 0]).flatten()

    if model_type == 'siam':
        # Each siam sample carries two images.
        x_buffer = [
            (flat_image(pair[0][0]), flat_image(pair[1][0]))
            for pair in raw_x
        ]
    elif model_type == 'cnn':
        x_buffer = [flat_image(sample[0]) for sample in raw_x]
    else:
        x_buffer = raw_x

    assert len(x_buffer) == len(y_buffer)

    return x_buffer, y_buffer