def shiftGaugeField(self, gaugeField, cpt, sign):
    """Shift a gauge field by one lattice site along direction ``cpt``.

    Directions other than 0 are shifted with plain periodic wrap-around
    (``tf.roll``). Along direction 0 the wrapped-around slice is overwritten
    so that the boundary keeps its own unshifted values (reflecting BC's).
    When stepping backwards (``sign == -1``), the direction-0 links at the
    origin are also set to the 2x2 identity.

    Args:
        gaugeField: Gauge-link tensor. The origin update below indexes the
            first four axes and writes 2x2 blocks. That suggests a layout of
            [x0, x1, x2, direction, 2, 2] (TODO confirm against callers).
        cpt: Lattice direction to shift along.
        sign: +1 to move one site forwards, -1 to move one site backwards.

    Returns:
        A new shifted tensor; ``gaugeField`` itself is not modified.
    """
    # Moving one site forwards is equivalent to shifting the whole field
    # backwards, hence the minus sign (active/passive transform)
    gaugeFieldShifted = tf.roll(gaugeField, -sign, cpt)

    if cpt != 0:
        return gaugeFieldShifted

    # Apply reflecting BC's by setting links at the boundary to the
    # corresponding values from the unshifted field. After rolling by -sign,
    # the slice adjacent to the boundary holds exactly those values.
    # For sign=+1, shifted[N-2] == field[N-1]. For sign=-1, shifted[1] == field[0].
    if sign == +1:
        slicePos = self.latShape[cpt] - 2
    else:
        slicePos = 1

    # For gathering from the shifted field (gathering from the variable is slow)
    indices = FieldTools.sliceIndices(self.latShape, cpt, slicePos)
    # For scattering onto the boundary
    boundaryIndices = FieldTools.boundaryIndices(self.latShape, cpt, sign)

    updates = tf.gather_nd(gaugeFieldShifted, indices)
    gaugeFieldShifted = tf.tensor_scatter_nd_update(
        gaugeFieldShifted, boundaryIndices, updates)

    if sign == -1:
        # Set the r-links (link direction 0) at the origin (x0 = 0) to the
        # identity, for every (x1, x2). Each index row is (0, x1, x2, 0).
        rOriginIndices = tf.stack(
            tf.meshgrid(0, tf.range(self.latShape[1]),
                        tf.range(self.latShape[2]), 0, indexing="ij"), -1)
        # Match the field's dtype: tensor_scatter_nd_update requires the
        # updates and the target tensor to share a dtype.
        rOriginUpdates = tf.eye(2, batch_shape=tf.shape(rOriginIndices)[0:-1],
                                dtype=gaugeFieldShifted.dtype)
        gaugeFieldShifted = tf.tensor_scatter_nd_update(
            gaugeFieldShifted, rOriginIndices, rOriginUpdates)

    return gaugeFieldShifted
def shiftScalarField(self, scalarField, cpt, sign):
    """Shift a scalar field by one lattice site along direction ``cpt``.

    The shift is a ``tf.roll``, so every direction wraps around
    periodically. The exception is direction 0: there, the site that wrapped
    around is replaced by the unshifted boundary value (reflecting BC's).

    Args:
        scalarField: Field tensor whose leading axes are the lattice axes.
        cpt: Lattice direction to shift along.
        sign: +1 to move one site forwards, -1 to move one site backwards.

    Returns:
        A new shifted tensor; ``scalarField`` itself is not modified.
    """
    # Stepping one site forwards means moving the field one site backwards
    # (active vs. passive transform), which is why the roll uses -sign.
    rolled = tf.roll(scalarField, -sign, cpt)

    if cpt == 0:
        # After the roll, the slice next to the boundary already contains
        # the original boundary values. For sign=+1, rolled[N-2] == field[N-1].
        # For sign=-1, rolled[1] == field[0]. Reading them from the rolled
        # tensor avoids a slow gather from the original variable.
        neighbourPos = self.latShape[cpt] - 2 if sign == +1 else 1
        sourceIdx = FieldTools.sliceIndices(self.latShape, cpt, neighbourPos)
        # Destination: the boundary slice selected by FieldTools for this sign
        targetIdx = FieldTools.boundaryIndices(self.latShape, cpt, sign)
        rolled = tf.tensor_scatter_nd_update(
            rolled, targetIdx, tf.gather_nd(rolled, sourceIdx))

    return rolled