Пример #1
0
class ColSlicingBlock(object):
    """
    Select a subset of the columns of ``W`` into a new output connector.

    Parameters
    ----------
    W : Matrix (GpuMatrix or CpuMatrix)
        Source matrix. Its ``device_id`` determines where the output and the
        execution context live. When ``W.bpropagable`` is true, gradients
        are routed back through a sparse backward matrix.
    col_indexes
        Indexes of the columns to pick (presumably an integer Matrix — TODO
        confirm). The output has ``W.nrows`` rows and
        ``col_indexes.ncols`` columns.

    """
    def __init__(self, W, col_indexes):
        dev = W.device_id
        self.context = Context(dev)
        needs_grad = W.bpropagable
        if not needs_grad:
            self.W = W.register_usage(dev)
        else:
            # Learning mode: also obtain a sparse gradient accumulator for W.
            self.W, self.dL_dW = W.register_usage_with_sparse_backward_matrix()
        self.col_indexes = col_indexes.register_usage(dev)
        sliced = Matrix.empty(W.nrows, col_indexes.ncols, device_id=dev)
        # The connector only receives a device when gradients are needed.
        connector_device = dev if needs_grad else None
        self.output = Connector(sliced, connector_device)

    def fprop(self):
        """Fill the output with the selected columns of W, then forward it."""
        out = self.output
        self.W.slice_columns(self.context, self.col_indexes, out)
        out.fprop()

    def bprop(self):
        """Push the output gradient into dL_dW; no-op when W is not learnable."""
        # dL_dW is only assigned in __init__ when W.bpropagable was true.
        if not hasattr(self, 'dL_dW'):
            return
        grad = self.output.bprop()
        self.dL_dW.add_columns_slice(self.col_indexes, grad)
Пример #2
0
class RowSlicingBlock(object):
    """
    Select rows of ``W`` according to ``row_indexes``.

    Parameters
    ----------
    W : Matrix (GpuMatrix or CpuMatrix)
        Source matrix. Its ``device_id`` determines where outputs and the
        execution context live. When ``W.bpropagable`` is true, a backward
        matrix is registered as well.
    row_indexes
        Row indexes to pick (presumably an integer Matrix — TODO confirm).
        Each output is ``row_indexes.nrows`` x ``W.ncols``. When it has more
        than one column, one output per column is built and wrapped in a
        ``List``. Otherwise a single ``Connector`` is used.
    dense : bool
        In learning mode, selects a dense backward matrix (True) or a sparse
        one (False). This also decides whether ``bprop`` passes the context
        to the gradient update.
    """
    def __init__(self, W, row_indexes, dense=True):
        self.dense = dense
        dev = W.device_id
        self.context = Context(dev)
        needs_grad = W.bpropagable
        if not needs_grad:
            self.W = W.register_usage(dev)
        elif dense:
            self.W, self.dL_dW = W.register_usage(dev, dev)
        else:
            self.W, self.dL_dW = W.register_usage_with_sparse_backward_matrix()
        self.row_indexes = row_indexes.register_usage(dev)

        # The connectors only receive a device when gradients are needed.
        connector_device = dev if needs_grad else None
        n_slices = row_indexes.ncols

        def new_connector():
            # One (row_indexes.nrows x W.ncols) buffer wrapped in a Connector.
            buf = Matrix.empty(row_indexes.nrows, W.ncols, device_id=dev)
            return Connector(buf, connector_device)

        if n_slices > 1:
            connectors = [new_connector() for _ in xrange(n_slices)]
            self.output = List(connectors, n_slices)
        else:
            self.output = new_connector()

    def fprop(self):
        """Fill the output(s) with the selected rows of W, then forward."""
        if isinstance(self.output, List):
            slicer = self.W.slice_rows_batch
        else:
            slicer = self.W.slice_rows
        slicer(self.context, self.row_indexes, self.output)
        self.output.fprop()

    def bprop(self):
        """Push output gradients into dL_dW; no-op when W is not learnable."""
        # dL_dW is only assigned in __init__ when W.bpropagable was true.
        if not hasattr(self, 'dL_dW'):
            return
        batched = isinstance(self.output, List)
        update = self.dL_dW.add_rows_batch_slice if batched else self.dL_dW.add_rows_slice
        grad = self.output.bprop()
        # The dense update takes the context as its first argument; the sparse one does not.
        if not self.dense:
            update(self.row_indexes, grad)
        else:
            update(self.context, self.row_indexes, grad)