def apply(self, fgraph):
    """Replace ``max_and_argmax`` by a max-only ``CAReduce`` where possible.

    A ``T._max_and_argmax`` node whose argmax output (``outputs[1]``) has
    no clients only needs its max output, so that output is replaced by
    ``CAReduce(scal.maximum, axis)`` applied to the same input.

    Nodes whose axis input is not a scalar constant are skipped, as are
    replacements that ``fgraph.replace_all_validate`` rejects with an
    ``InconsistencyError``; the remaining nodes are still rewritten.

    :param fgraph: the function graph, modified in place.
    """
    did_something = True
    while did_something:
        # Every successful replacement changes the graph, so recompute the
        # topological order and rescan from the start.
        nodelist = fgraph.toposort()
        did_something = False
        for node in nodelist:
            if node.op != T._max_and_argmax:
                continue
            if node.outputs[1].clients:
                # The argmax output is used, so the node must stay.
                continue
            try:
                axis = get_scalar_constant_value(node.inputs[1])
            except NotScalarConstantError:
                # Skip only this node: a non-constant axis here must not
                # prevent rewriting the other max_and_argmax nodes.
                continue

            new = CAReduce(scal.maximum, axis)(node.inputs[0])
            try:
                fgraph.replace_all_validate(
                    ((node.outputs[0], new),),
                    reason=self.__class__.__name__)
            except InconsistencyError:
                # Best effort: leave this node untouched if the graph
                # rejects the replacement.
                continue
            did_something = True
            break
def make_node(self, x, repeats):
    """Build the Apply node that repeats ``x`` according to ``repeats``.

    :param x: input, converted with ``basic.as_tensor_variable``.
    :param repeats: repetition count(s), converted the same way; its dtype
        must be in ``tensor.discrete_dtypes``.
    :returns: ``theano.Apply(self, [x, repeats], [out])`` where ``out`` has
        ``x.dtype``. If ``self.axis`` is None the output is 1-d and
        non-broadcastable; otherwise it keeps ``x``'s broadcast pattern,
        except that dimension ``self.axis`` becomes non-broadcastable unless
        ``repeats`` is the constant scalar 1.
    :raises TypeError: if ``repeats`` is not of a discrete dtype, or has a
        dtype that numpy's ``repeat`` cannot accept given this platform's
        Python int bit width.
    """
    x = basic.as_tensor_variable(x)
    repeats = basic.as_tensor_variable(repeats)

    if repeats.dtype not in tensor.discrete_dtypes:
        raise TypeError("repeats.dtype must be an integer.")

    # Some dtypes are not supported by numpy's implementation of repeat.
    # Until another one is available, we should fail at graph construction
    # time, not wait for execution.
    int_bitwidth = theano.gof.cmodule.python_int_bitwidth()
    if int_bitwidth == 64:
        numpy_unsupported_dtypes = ('uint64',)
    elif int_bitwidth == 32:
        numpy_unsupported_dtypes = ('uint32', 'int64', 'uint64')
    else:
        # Unknown bit width: reject nothing here instead of hitting a
        # NameError on the membership test below.
        numpy_unsupported_dtypes = ()
    if repeats.dtype in numpy_unsupported_dtypes:
        # Format into a single message; a bare tuple on the right of `%`
        # is unpacked as multiple arguments and breaks the formatting.
        raise TypeError(
            "dtypes %s are not supported by numpy.repeat "
            "for the 'repeats' parameter, got %s"
            % (', '.join(numpy_unsupported_dtypes), repeats.dtype))

    if self.axis is None:
        # The input is flattened, so the output is a 1-d vector.
        broadcastable = [False]
    else:
        try:
            const_reps = basic.get_scalar_constant_value(repeats)
        except basic.NotScalarConstantError:
            const_reps = None
        if const_reps == 1:
            # Repeating each element once leaves the shape unchanged.
            broadcastable = x.broadcastable
        else:
            broadcastable = list(x.broadcastable)
            broadcastable[self.axis] = False

    out_type = theano.tensor.TensorType(x.dtype, broadcastable)
    return theano.Apply(self, [x, repeats], [out_type()])