Exemplo n.º 1
0
def eval_cleverhans():
    """Evaluate ``model`` on adversarial examples generated with cleverhans.

    Reads module-level names defined elsewhere in the file: ``images``,
    ``labels``, ``model``, ``batch_size`` and ``attack_params_dict``.
    ``attack_params_dict['attack']`` selects one of 'fgsm', 'deepfool',
    'madry' or 'carlini'; 'fgsm' additionally reads
    ``attack_params_dict['eps']``. Images are divided by 255 before the
    attack (presumably they hold [0, 255] pixel values — TODO confirm), and
    adversarial examples are clipped to [0, 1].

    The Keras learning phase is set to 0 (test) for the evaluation. If the
    previous phase was a fixed int it is restored afterwards, also when an
    exception is raised.

    Returns
    -------
    acc_adv : float
        Accuracy returned by cleverhans' ``model_eval`` on the adversarial
        examples.

    Raises
    ------
    NotImplementedError
        If ``attack_params_dict['attack']`` is not a supported attack.
    """
    # K.learning_phase() is not always an int (it can be a symbolic
    # placeholder), while K.set_learning_phase only accepts 0/1, so only an
    # int phase can be put back afterwards.
    previous_phase = K.learning_phase()
    restore_phase = isinstance(previous_phase, int)
    K.set_learning_phase(0)
    try:
        # Pre-process images
        images_tf = images.astype(K.floatx())
        images_tf /= 255.

        sess = K.get_session()

        # Wrapper for the Keras model
        model_wrap = KerasModelWrapper(model)

        # Initialize attack
        attack_name = attack_params_dict['attack']
        if attack_name == 'fgsm':
            attack = FastGradientMethod(model_wrap, sess=sess)
            attack_params = {'eps': attack_params_dict['eps'],
                             'clip_min': 0., 'clip_max': 1.}
        elif attack_name == 'deepfool':
            attack = DeepFool(model_wrap, sess=sess)
            attack_params = {'clip_min': 0., 'clip_max': 1.}
        elif attack_name == 'madry':
            attack = ProjectedGradientDescent(model_wrap, sess=sess)
            attack_params = {'clip_min': 0., 'clip_max': 1.}
        elif attack_name == 'carlini':
            attack = CarliniWagnerL2(model_wrap, sess=sess)
            attack_params = {'clip_min': 0., 'clip_max': 1.}
        else:
            raise NotImplementedError(
                'Unsupported attack: %s' % (attack_name,))

        # Define input TF placeholders
        x = tf.placeholder(K.floatx(), shape=(None,) + images.shape[1:])
        y = tf.placeholder(K.floatx(), shape=(None,) + (labels.shape[-1],))

        # Define adversarial predictions symbolically; no gradient flows
        # back through the attack.
        x_adv = tf.stop_gradient(attack.generate(x, **attack_params))
        predictions_adv = model(x_adv)

        # Evaluate the accuracy of the model on adversarial examples
        eval_par = {'batch_size': batch_size}
        acc_adv = model_eval(sess, x, y, predictions_adv, images_tf,
                             labels, args=eval_par)

        print('Adversarial accuracy against %s: %.4f\n' %
              (attack_name, acc_adv))
    finally:
        # Set original phase (only possible when it was a fixed int)
        if restore_phase:
            K.set_learning_phase(previous_phase)

    return acc_adv
Exemplo n.º 2
0
    def build(self, input_shape):
        """Create the layer weights for an input of rank 3.

        Declares an input spec of shape ``(None, n_steps, n_classes)``, then
        adds the ``(n_classes, n_classes)`` weight ``U`` and two
        zero-initialised ``(n_classes,)`` vectors ``b_start`` and ``b_end``.
        If ``self.initial_weights`` is set, those values are loaded and the
        attribute is deleted.
        """
        assert len(input_shape) == 3
        _, n_steps, n_classes = input_shape
        # A sequence needs at least two steps (unless the length is dynamic).
        assert n_steps is None or n_steps >= 2

        self.input_spec = [InputSpec(dtype=K.floatx(),
                                     shape=(None, n_steps, n_classes))]

        self.U = self.add_weight(name='U',
                                 shape=(n_classes, n_classes),
                                 initializer=self.init,
                                 regularizer=self.U_regularizer,
                                 constraint=self.U_constraint)

        self.b_start = self.add_weight(name='b_start',
                                       shape=(n_classes, ),
                                       initializer='zero',
                                       regularizer=self.b_start_regularizer,
                                       constraint=self.b_start_constraint)

        self.b_end = self.add_weight(name='b_end',
                                     shape=(n_classes, ),
                                     initializer='zero',
                                     regularizer=self.b_end_regularizer,
                                     constraint=self.b_end_constraint)

        # Optional user-supplied starting values, consumed once.
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

        self.built = True
Exemplo n.º 3
0
def path_energy0(y, x, U, mask=None):
    """Path energy without boundary potential handling.

    Returns, per batch element, the sum over time of the emission terms
    ``x[b, t, y[b, t]]`` plus the sum of the transition terms
    ``U[y[b, t], y[b, t + 1]]``. When ``mask`` is given, a transition only
    counts if both of its endpoints are unmasked.
    """
    n_classes = K.shape(x)[2]

    # Emission part: a one-hot of y selects x[b, t, y[b, t]]; sum over the
    # class axis, then over time.
    emission = K.sum(K.sum(x * K.one_hot(y, n_classes), 2), 1)

    # Transition part: a pair (i, j) lives at offset i * n_classes + j of the
    # row-major flattened U.
    src_tags = y[:, :-1]
    dst_tags = y[:, 1:]
    transitions = K.gather(K.reshape(U, [-1]),
                           src_tags * n_classes + dst_tags)

    if mask is not None:
        float_mask = K.cast(mask, K.floatx())
        transitions *= float_mask[:, :-1] * float_mask[:, 1:]

    return emission + K.sum(transitions, axis=1)
Exemplo n.º 4
0
def viterbi_decode(x, U, b_start=None, b_end=None, mask=None):
    """Computes the best tag sequence y for a given input x, i.e. the one
    that maximizes the value of path_energy.

    Boundary energies are added first, a max-product forward pass records
    the arg-max predecessor at every step, and ``_backward`` follows those
    back-pointers to recover the sequence.
    """
    def _max_reduce(scores):
        # Best predecessor index (stored as float in the state) and its score.
        return [K.cast(K.argmax(scores, axis=1), K.floatx()),
                K.max(scores, axis=1)]

    energies = add_boundary_energy(x, b_start, b_end, mask)
    first_scores = energies[:, 0, :]
    states = [K.zeros_like(first_scores), first_scores]
    _, backpointers = _forward(energies, _max_reduce, states, U, mask)
    return _backward(backpointers, mask)
Exemplo n.º 5
0
def get_model_memory_usage(net_list, batch_size, input_shape, target_shape, return_dict):
    """Estimate the memory (GiB) of the network encoded by ``net_list``.

    The result is stored in ``return_dict["memory"]`` rather than returned
    (``return_dict`` is presumably a shared dict filled by a worker process
    — TODO confirm). The estimate is::

        bytes_per_number * (batch_size * sum(layer output elements)
                            + trainable params + non-trainable params)

    where ``None`` output dimensions are ignored, only the first output shape
    of a multi-output layer is counted, and ``bytes_per_number`` follows
    ``K.floatx()`` (2 for float16, 8 for float64, 4 otherwise). The value is
    rounded to 3 decimals and the Keras session is cleared afterwards.

    If building the model raises ``tf.errors.ResourceExhaustedError`` or
    ``KeyError``, the error is printed and the sentinel 1000 is stored.
    """
    try:
        model = dag_2_cnn(cgp_2_dag(net_list), 0, input_shape, target_shape, compile=False)
    except (tf.errors.ResourceExhaustedError, KeyError) as e:
        print(e)
        return_dict["memory"] = 1000
        return

    # Number of elements in all layer outputs for a single sample.
    shapes_mem_count = 0
    for layer in model.layers:
        out_shape = layer.output_shape
        # Multi-output layers report a list of shapes; count the first only.
        if isinstance(out_shape, list):
            out_shape = out_shape[0]
        layer_elements = 1
        for dim in out_shape:
            if dim is not None:
                layer_elements *= dim
        shapes_mem_count += layer_elements

    trainable_count = np.sum([K.count_params(p) for p in model.trainable_weights])
    non_trainable_count = np.sum([K.count_params(p) for p in model.non_trainable_weights])

    # Bytes per stored number for the active Keras float type.
    number_size = {'float16': 2.0, 'float64': 8.0}.get(K.floatx(), 4.0)

    total_memory = number_size * (batch_size * shapes_mem_count + trainable_count + non_trainable_count)
    gbytes = np.round(total_memory / (1024.0 ** 3), 3)
    K.clear_session()

    return_dict["memory"] = gbytes
Exemplo n.º 6
0
def add_boundary_energy(x, b_start=None, b_end=None, mask=None):
    """Add boundary energies to the first and last steps of ``x``.

    ``b_start`` is added to the start element and ``b_end`` to the end
    element (either may be None to skip it). Without a mask, these are the
    first and last time steps. With a mask, ``x`` is first multiplied by the
    mask, and the start (resp. end) elements are the steps where the mask is
    on while the previous (resp. next) step is off.
    """
    if mask is None:
        if b_start is not None:
            head, tail = x[:, :1, :], x[:, 1:, :]
            x = K.concatenate([head + b_start, tail], axis=1)
        if b_end is not None:
            body, last = x[:, :-1, :], x[:, -1:, :]
            x = K.concatenate([body, last + b_end], axis=1)
        return x

    mask = K.expand_dims(K.cast(mask, K.floatx()), 2)
    x *= mask
    if b_start is not None:
        # Mask shifted one step to the right (zero-padded at the front).
        prev_mask = K.concatenate([K.zeros_like(mask[:, :1]), mask[:, :-1]],
                                  axis=1)
        x = x + K.cast(K.greater(mask, prev_mask), K.floatx()) * b_start
    if b_end is not None:
        # Mask shifted one step to the left (zero-padded at the back).
        next_mask = K.concatenate([mask[:, 1:], K.zeros_like(mask[:, -1:])],
                                  axis=1)
        x = x + K.cast(K.greater(mask, next_mask), K.floatx()) * b_end
    return x
Exemplo n.º 7
0
    def call(self, x, mask=None):
        """Attention pooling over the step axis of ``x``.

        Scores every step with ``tanh(x . W (+ b))``, turns the scores into
        weights with ``exp`` normalised over the steps (masked steps get
        weight 0), and returns the weighted sum of ``x`` over the step axis.
        """
        n_features = self.features_dim
        n_steps = self.step_dim

        # e = K.dot(x, self.W), computed on a 2-D view of x and reshaped back.
        flat_x = K.reshape(x, (-1, n_features))
        flat_w = K.reshape(self.W, (n_features, 1))
        scores = K.reshape(K.dot(flat_x, flat_w), (-1, n_steps))
        if self.bias:
            scores += self.b
        weights = K.exp(K.tanh(scores))

        # Zero out masked steps before normalising; the cast to floatX avoids
        # float64 upcasting in theano.
        if mask is not None:
            weights *= K.cast(mask, K.floatx())
        # Early in training the sum can be almost zero, producing NaNs, so a
        # small epsilon is added to the denominator.
        total = K.sum(weights, axis=1, keepdims=True) + K.epsilon()
        weights /= K.cast(total, K.floatx())

        return K.sum(K.expand_dims(weights) * x, axis=1)
Exemplo n.º 8
0
def orthonorm_op(x, epsilon=1e-7):
    '''
    Computes a matrix that orthogonalizes the input matrix x

    x:        an n x d input matrix (its second dimension must be static)
    epsilon:  value added to the diagonal of x^T x so that its Cholesky
              factorization exists even when x^T x is singular

    returns:  a d x d matrix W = sqrt(n) * (L^-1)^T, where L is the Cholesky
              factor of x^T x + epsilon * I. Right-multiplying gives
              (x W)^T (x W) ~= n * I, i.e. W orthogonalizes x
    '''
    gram = K.dot(K.transpose(x), x) + K.eye(K.int_shape(x)[1]) * epsilon
    chol = tf.cholesky(gram)
    n_rows = tf.cast(tf.shape(x)[0], dtype=K.floatx())
    return tf.transpose(tf.matrix_inverse(chol)) * tf.sqrt(n_rows)
def lab2rgb(x):
    """
    Converts a CIE Lab image into RGB, dtype=K.floatx() in [0, 1]

    The conversion is delegated to ``skico.lab2rgb`` (presumably
    ``skimage.color`` — TODO confirm), and the result is cast to the Keras
    float type.

    Parameters
    ----------
    x : ndarray
        Lab image

    Returns
    -------
    ndarray
        RGB image with dtype ``K.floatx()``
    """
    return skico.lab2rgb(x).astype(K.floatx())
Exemplo n.º 10
0
def get_activations(activation_function, batch_gen):
    """
    Computes the activations of a data set at one layer of the model in a
    "delayed" way (for memory and computation efficiency) and returns them as
    a dask array.

    Each batch becomes a ``dask.delayed`` task, so nothing is computed until
    (part of) the returned array is computed.

    See: https://docs.dask.org/en/latest/delayed.html

    Parameters
    ----------
    activation_function : callable
        Called as ``activation_function([batch_images, 0])``; element ``[0]``
        of the result is taken as the batch activations. Its ``outputs[0]``
        must have a static shape (read with ``K.int_shape``).
    batch_gen : object
        Must expose ``n_images``, ``aug_per_im``, ``batch_size`` and
        ``n_batches``, and be callable, returning an iterator whose ``next``
        yields ``(batch_images, _)``.

    Returns
    -------
    activations_da : dask array
        Shape ``(n_images, layer_dim, aug_per_im)``, where ``layer_dim`` is
        the product of the layer output dimensions (batch axis excluded).
    """

    # Per-sample output shape of the layer (batch axis dropped) and its total
    # number of elements.
    layer_shape = K.int_shape(activation_function.outputs[0])[1:]
    layer_dim = np.prod(K.int_shape(activation_function.outputs[0])[1:])
    n_images = batch_gen.n_images
    n_aug = batch_gen.aug_per_im
    batch_size = batch_gen.batch_size

    # Delayed computation of the activations of a batch.
    # NOTE(review): the task takes no arguments and pulls the next batch from
    # the shared generator when it runs, so which data lands in which chunk
    # depends on the order in which dask executes these tasks. A parallel
    # scheduler may not preserve creation order — confirm before relying on
    # row order (or compute with a synchronous scheduler).
    @dask.delayed
    def batch_activation():
        batch_images, _ = next(batch_gen())
        # The second input (0) is presumably the Keras learning phase, i.e.
        # test mode — TODO confirm.
        return activation_function([batch_images, 0])[0]

    # Delayed iteration over the data set
    activations_delayed = [batch_activation() for _
            in range(batch_gen.n_batches)]
    # Every chunk is declared as a full batch of batch_size * n_aug rows.
    activations_da_list = [da.from_delayed(
            activation_delayed,
            shape=(batch_size * n_aug, ) + layer_shape,
            dtype=K.floatx())
        for activation_delayed in activations_delayed]
    activations_da = da.concatenate(activations_da_list, axis=0)

    # The last batch can be smaller: keep only the rows that exist.
    activations_da = activations_da[:n_images * n_aug]

    # Reshape the activations such that
    # shape = (n_diff_images, layer_dim, n_aug).
    # This assumes the n_aug augmented copies of each image occupy
    # consecutive rows — TODO confirm against batch_gen.
    activations_da = da.reshape(activations_da, 
                                (activations_da.shape[0], layer_dim))
    activations_da = da.transpose(da.reshape(activations_da.T, 
                                             (layer_dim, n_images, n_aug)),
                                  (1, 0, 2))

    return activations_da
Exemplo n.º 11
0
 def call(self, inputs, **kwargs):
     """Max-pool ``inputs`` and also return the arg-max indices.

     Only the TensorFlow backend is supported (it provides
     ``tf.nn.max_pool_with_argmax``). Returns ``[pooled, argmax]``, with the
     indices cast to ``K.floatx()``.
     """
     if K.backend() != "tensorflow":
         errmsg = "{} backend is not supported for layer {}".format(
             K.backend(),
             type(self).__name__)
         raise NotImplementedError(errmsg)

     pool_size = self.pool_size
     strides = self.strides
     output, argmax = tf.nn.max_pool_with_argmax(
         inputs,
         ksize=[1, pool_size[0], pool_size[1], 1],
         strides=[1, strides[0], strides[1], 1],
         padding=self.padding.upper())
     return [output, K.cast(argmax, K.floatx())]
Exemplo n.º 12
0
def _forward(x, reduce_step, initial_states, U, mask=None):
    """Forward recurrence of the linear chain crf.

    At each step the last state is broadcast against that step's
    energy matrix (emission of the next element plus transition ``U``), and
    ``reduce_step`` turns the result into the new list of states. Its first
    entry is the per-step output collected by ``K.rnn``. When ``mask`` is
    given, a transition contributes only if both of its endpoints are
    unmasked. Returns ``(last_output, outputs)`` from ``K.rnn``.
    """
    def _step(step_energy, states):
        prev_alpha = states[-1]
        reduced = reduce_step(K.expand_dims(prev_alpha, 2) + step_energy)
        return reduced[0], reduced

    transition = K.expand_dims(K.expand_dims(U, 0), 0)
    if mask is not None:
        mask = K.cast(mask, K.floatx())
        pair_mask = mask[:, :-1] * mask[:, 1:]
        transition = transition * K.expand_dims(K.expand_dims(pair_mask, 2), 3)

    step_inputs = K.expand_dims(x[:, 1:, :], 2) + transition
    # Append one all-zero step so the number of steps matches x.
    zero_step = K.zeros_like(step_inputs[:, -1:, :, :])
    step_inputs = K.concatenate([step_inputs, zero_step], axis=1)

    last_output, outputs, _ = K.rnn(_step, step_inputs, initial_states)
    return last_output, outputs
def rgb2lab(x):
    """
    Converts an RGB image into the CIE Lab space

    uint8 images are cast to ``K.floatx()`` and divided by 255. Any other
    image is divided by 255 only if its maximum exceeds 1 (presumably it then
    holds [0, 255] values — TODO confirm). The caller's array is never
    modified in place.

    Parameters
    ----------
    x : ndarray
        RGB image

    Returns
    -------
    x_lab : ndarray
        Lab image
    """
    if x.dtype == 'uint8':
        x = x.astype(K.floatx()) / 255.
    elif np.max(x) > 1.:
        # Out-of-place division: the former ``x /= 255.`` overwrote the
        # caller's float array and raised for integer dtypes other than uint8.
        x = x / 255.
    x_lab = skico.rgb2lab(x)

    return x_lab
    def flow_dask(self, x, y=None, batch_size=32, aug_per_im=1, shuffle=True,
                  sample_weight=None, seed_shuffle=None, save_to_dir=None,
                  save_prefix='', save_format='png', subset=None,
                  dtype=K.floatx()):
        """Return a ``DaskArrayIterator`` over ``x`` (and optionally ``y``).

        The target shape used for image cropping is ``self.target_shape``
        when set, otherwise the per-sample shape of ``x``. ``seed_shuffle``
        is forwarded as the iterator's ``seed`` and ``self.data_format`` as
        its ``data_format``. The default ``dtype`` is ``K.floatx()``,
        evaluated once when the method is defined.
        """
        if self.target_shape is not None:
            target_shape = self.target_shape
        else:
            target_shape = list(x.shape)[1:]

        return DaskArrayIterator(
            x, y, target_shape, self,
            batch_size=batch_size,
            aug_per_im=aug_per_im,
            shuffle=shuffle,
            sample_weight=sample_weight,
            seed=seed_shuffle,
            data_format=self.data_format,
            save_to_dir=save_to_dir,
            save_prefix=save_prefix,
            save_format=save_format,
            subset=subset,
            dtype=dtype)
Exemplo n.º 15
0
import sklearn
import h5py
from tqdm import tqdm
from scipy import signal
from scipy import stats
from scipy import fftpack
from statsmodels.robust import mad
# Hide TensorFlow C++ log messages.
# NOTE(review): this variable is generally only honoured if set before
# TensorFlow is imported; `tf` is already in use below, so check where it is
# imported — TODO confirm this setting has any effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Silence all Python warnings unless -W options were given on the command line.
if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

# Report library versions and the Keras default float type.
print(tf.version.VERSION)
print(tf.keras.__version__)
print(K.floatx())
print("Finished Importing Files")
# Experiment settings. Meanings are presumably: window length in samples,
# number of features, number of CV folds, random seed — TODO confirm.
window = 150000
n_feat = 18
n_fold = 3
seed = 65
print("Finished initializing variables")
# train.csv has no header row; read every column as float64 using pandas'
# high-precision float parser.
data = pandas.read_csv("train.csv",
                       header=None,
                       dtype=numpy.float64,
                       float_precision='high')
print("Finished Reading the CSV File")
data = data.to_numpy()
# Column 1 -> `output`; column 0 -> `training`, cast to float32.
output = data[:, 1]
training = numpy.float32(data[:, 0])
# NOTE(review): `output` is a slice (view) of `data`, so this `del` does not
# release the full array's memory; use data[:, 1].copy() if that matters.
del data
Exemplo n.º 16
0
def test(images, labels, batch_size, model, model_adv, image_params_dict, 
         attack_params_dict, output_file=None, do_print=True):
    """
    Tests the performance of a model on adversarial images. The adversarial
    images are computed according to the attack specified in the arguments.

    Parameters
    ----------
    images : dask array
        The set of images

    labels : dask array
        The ground truth labels

    batch_size : int
        Batch size (overridden by the batch size returned by ``init_attack``
        when that one is truthy)

    model : Keras Model
        The model

    model_adv : Keras Model
        The model used to generate adversarial examples

    image_params_dict : dict
        Dictionary of data augmentation parameters (must contain
        ``'crop_size'``)

    attack_params_dict : dict
        Dictionary of the attack parameters

    output_file : str or None
        The file to write the results to

    do_print : bool
        Whether to print the adversarial accuracy and MSE

    Returns
    -------
    accuracy : float
        Accuracy of ``model`` on the adversarial images
    mse : float
        Mean squared error between the adversarial and original images
    """
    sess = K.get_session()

    # K.learning_phase() is only an int when a fixed phase has been set; in
    # that case switch to test mode and restore the previous value at the end
    # (also when an exception is raised).
    previous_phase = K.learning_phase()
    restore_phase = isinstance(previous_phase, int)
    if restore_phase:
        K.set_learning_phase(0)

    try:
        # Initialize adversarial attack
        attack, attack_params, bs = init_attack(model_adv, attack_params_dict,
                                                sess)
        if bs:
            batch_size = bs

        # Create batch generator
        image_gen = data_input.get_generator(images, **image_params_dict)
        batch_gen = batch_generator(image_gen, images, labels, batch_size,
                                    aug_per_im=1, shuffle=False)
        n_batches_per_epoch = int(np.ceil(float(images.shape[0]) / batch_size))

        # Define input TF placeholders; `bs` (possibly None) fixes the batch
        # dimension.
        if image_params_dict['crop_size']:
            image_shape = image_params_dict['crop_size']
        else:
            image_shape = images.shape[1:]
        x = tf.placeholder(K.floatx(), shape=(bs,) + tuple(image_shape))
        y = tf.placeholder(K.floatx(), shape=(bs,) + (labels.shape[-1],))

        # Define adversarial predictions symbolically
        x_adv = tf.stop_gradient(attack.generate(x, **attack_params))
        predictions_adv = model(x_adv)

        # Define accuracy and mean squared error symbolically
        correct_preds = tf.equal(tf.argmax(y, axis=-1),
                                 tf.argmax(predictions_adv, axis=-1))
        acc_value = tf.reduce_mean(tf.to_float(correct_preds))
        mse_value = tf.reduce_mean(tf.square(tf.subtract(x, x_adv)))

        accuracy = 0.0
        mse = 0.0
        with sess.as_default():
            for _ in tqdm(range(n_batches_per_epoch)):
                batch = next(batch_gen())
                this_batch_size = batch[0].shape[0]

                # Multi-output generators yield a list of targets; use the
                # first one.
                if isinstance(batch[1], list):
                    yy = batch[1][0]
                else:
                    yy = batch[1]

                # Fetch both metrics in one run so the adversarial batch is
                # generated once and both metrics use the same examples
                # (separate evals ran the attack twice).
                batch_acc, batch_mse = sess.run(
                    [acc_value, mse_value], feed_dict={x: batch[0], y: yy})

                # Accumulate batch-size-weighted sums
                accuracy += this_batch_size * batch_acc
                mse += this_batch_size * batch_mse

        accuracy /= images.shape[0]
        mse /= images.shape[0]
    finally:
        # Set original phase
        if restore_phase:
            K.set_learning_phase(previous_phase)

    if do_print:
        print('Adversarial accuracy against %s: %.4f' %
              (attack_params_dict['attack'], accuracy))
        print('MSE between %s adversaries and originals: %.4f\n' %
              (attack_params_dict['attack'], mse))

    if output_file is not None:
        _write_results(accuracy, output_file)

    return accuracy, mse
Exemplo n.º 17
0
def test_adv(images, labels, batch_size, model, adv_model, daug_params, 
             attack_params):
    """
    Tests the performance of a model on adversarial images. The adversarial
    images are computed according to the attack specified in the arguments.

    Parameters
    ----------
    images : dask array
        The set of images

    labels : dask array
        The ground truth labels (one column per class)

    batch_size : int
        Batch size (overridden by the batch size returned by ``init_attack``
        when that one is truthy)

    model : Keras Model
        The model

    adv_model : Keras Model
        The model used to generate adversarial examples

    daug_params : dict
        Dictionary of data augmentation parameters; a truthy ``'crop_size'``
        entry sets the input image shape

    attack_params : dict
        Dictionary of the attack parameters

    Returns
    -------
    results_dict : dict
        ``{'mean_acc': adversarial accuracy,
        'mean_mse': MSE between adversarial and original images}``
    """
    sess = K.get_session()

    # Initialize adversarial attack
    attack, attack_params_cleverhans, bs = init_attack(
            adv_model, attack_params, sess)
    if bs:
        batch_size = bs

    n_images = images.shape[0]
    n_classes = labels.shape[1]
    n_batches_per_epoch = int(np.ceil(float(n_images) / batch_size))

    # Create batch generator
    image_gen = get_generator(images, **daug_params)
    batch_gen = batch_generator(image_gen, images, labels, batch_size,
                                aug_per_im=1, shuffle=False)

    # Define input TF placeholders; `bs` (possibly None) fixes the batch
    # dimension.
    crop_size = daug_params.get('crop_size')
    if crop_size:
        image_shape = crop_size
    else:
        image_shape = images.shape[1:]
    x = tf.placeholder(K.floatx(), shape=(bs,) + tuple(image_shape))
    y = tf.placeholder(K.floatx(), shape=(bs,) + (n_classes,))

    # Define adversarial predictions symbolically
    x_adv = tf.stop_gradient(attack.generate(x, **attack_params_cleverhans))
    predictions_adv = model(x_adv)

    # Define accuracy and mean squared error symbolically
    correct_preds = tf.equal(tf.argmax(y, axis=-1),
                             tf.argmax(predictions_adv, axis=-1))
    acc_value = tf.reduce_mean(tf.to_float(correct_preds))
    mse_value = tf.reduce_mean(tf.square(tf.subtract(x, x_adv)))

    # Run in test mode. A symbolic learning phase is fed 0; a fixed int phase
    # is already baked into the graph and is not a valid feed_dict key.
    learning_phase = K.learning_phase()
    if isinstance(learning_phase, int):
        phase_feed = {}
    else:
        phase_feed = {learning_phase: 0}

    accuracy = 0.0
    mse = 0.0
    with sess.as_default():
        for _ in tqdm(range(n_batches_per_epoch)):
            batch = next(batch_gen())
            this_batch_size = batch[0].shape[0]

            # Multi-output generators yield a list of targets; use the first.
            if isinstance(batch[1], list):
                yy = batch[1][0]
            else:
                yy = batch[1]

            feed_dict = {x: batch[0], y: yy}
            feed_dict.update(phase_feed)
            # One run: the adversarial batch is generated once and both
            # metrics are computed on the same examples.
            batch_acc, batch_mse = sess.run([acc_value, mse_value],
                                            feed_dict=feed_dict)

            # Accumulate batch-size-weighted sums
            accuracy += this_batch_size * batch_acc
            mse += this_batch_size * batch_mse

    accuracy /= n_images
    mse /= n_images

    results_dict = {'mean_acc': accuracy,
                    'mean_mse': mse}

    return results_dict
Exemplo n.º 18
0
 def call(self, inputs, **kwargs):
     # 1.0 where the input is non-zero and 0.0 elsewhere, in the Keras
     # float type.
     is_nonzero = K.not_equal(inputs, 0)
     return K.cast(is_nonzero, K.floatx())
    def __init__(self,
                 x,
                 y,
                 target_shape,
                 image_data_generator,
                 batch_size=32,
                 aug_per_im=1,
                 shuffle=False,
                 sample_weight=None,
                 seed=None,
                 data_format=None,
                 save_to_dir=None,
                 save_prefix='',
                 save_format='png',
                 subset=None,
                 ignore_class_split=False,
                 dtype=K.floatx()):
        """Iterator over dask arrays that materialises one chunk at a time.

        Parameters
        ----------
        x : dask array, or tuple/list ``(x, x_misc)``
            Images; ``x_misc`` is an array or list of arrays with the same
            length as ``x``. The first chunk must have rank 4.
        y : dask array, tuple/list ``(y, y_misc)``, or None
            Targets, with the same length as ``x``.
        target_shape : sequence
            Target shape passed on by the caller (stored as-is).
        image_data_generator : object
            Generator used for augmentation; its ``_validation_split`` is
            read when ``subset`` is given.
        subset : {'training', 'validation'} or None
            With a validation split ``s``, 'validation' keeps the first
            ``int(len(x) * s)`` samples and 'training' the rest.
        ignore_class_split : bool
            Skip the check that both subsets contain the same label values.
        dtype : str
            dtype of the materialised image chunk. The default is
            ``K.floatx()`` evaluated once, when the class is defined.
        Other arguments follow Keras' ``NumpyArrayIterator``.

        Raises
        ------
        ValueError
            On length mismatches, an invalid ``subset``, subsets with
            different label values, or a first chunk whose rank is not 4.
        """
        # Note that most lines are adapted from the __init__ function of
        # NumpyArrayIterator. Importantly, np.asarray(x) (or on y) is never
        # performed here, since the memory could be filled up.
        self.dtype = dtype
        if (type(x) is tuple) or (type(x) is list):
            if type(x[1]) is not list:
                x_misc = [x[1]]
            else:
                x_misc = [xx for xx in x[1]]
            x = x[0]
            for xx in x_misc:
                if len(x) != len(xx):
                    raise ValueError(
                        'All of the arrays in `x` '
                        'should have the same length. '
                        'Found a pair with: len(x[0]) = %s, len(x[?]) = %s' %
                        (len(x), len(xx)))
        else:
            x_misc = []

        if (type(y) is tuple) or (type(y) is list):
            if type(y[1]) is not list:
                y_misc = [y[1]]
            else:
                y_misc = [yy for yy in y[1]]
            y = y[0]
            for yy in y_misc:
                if len(y) != len(yy):
                    raise ValueError(
                        'All of the arrays in `y` '
                        'should have the same length. '
                        'Found a pair with: len(y[0]) = %s, len(y[?]) = %s' %
                        (len(y), len(yy)))
        else:
            y_misc = []

        if y is not None and len(x) != len(y):
            raise ValueError('x (images tensor) and y (labels) '
                             'should have the same length. '
                             'Found: X.shape = %s, y.shape = %s' %
                             (x.shape, y.shape))
        if sample_weight is not None and len(x) != len(sample_weight):
            raise ValueError('`x` (images tensor) and `sample_weight` '
                             'should have the same length. '
                             'Found: x.shape = %s, sample_weight.shape = %s' %
                             (x.shape, sample_weight.shape))
        if subset is not None:
            if subset not in {'training', 'validation'}:
                raise ValueError('Invalid subset name:', subset,
                                 '; expected "training" or "validation".')
            split_idx = int(len(x) * image_data_generator._validation_split)

            # Both halves must contain the same set of label values. Each
            # dask.unique result is computed before comparing; previously
            # .compute() was applied to np.array_equal's bool result, which
            # raised AttributeError.
            if (y is not None and not ignore_class_split
                    and not np.array_equal(
                        da.unique(y[:split_idx]).compute(),
                        da.unique(y[split_idx:]).compute())):
                raise ValueError('Training and validation subsets '
                                 'have different number of classes after '
                                 'the split. If your numpy arrays are '
                                 'sorted by the label, you might want '
                                 'to shuffle them.')

            if subset == 'validation':
                x = x[:split_idx]
                x_misc = [xx[:split_idx] for xx in x_misc]
                if y is not None:
                    y = y[:split_idx]
            else:
                x = x[split_idx:]
                x_misc = [xx[split_idx:] for xx in x_misc]
                if y is not None:
                    y = y[split_idx:]

        # Define the dask arrays and the chunk size. The size is assumed to be
        # the 0th dimension of the chunks and the other dimensions should be
        # equal to x.shape[1:]
        self.x_dask = x
        self.x_misc_dask = x_misc
        self.chunk_size = self.x_dask.chunks[0][0]

        # First chunk (the only data materialised in memory here)
        self.x = np.asarray(self.x_dask[:self.chunk_size], dtype=self.dtype)
        self.x_misc = [
            np.asarray(xx[:self.chunk_size]) for xx in self.x_misc_dask
        ]
        self.chunk_index = 0

        if y is not None:
            self.y_dask = y
            self.y = np.asarray(self.y_dask[:self.chunk_size])
            self.y_misc_dask = y_misc
            self.y_misc = [
                np.asarray(yy[:self.chunk_size]) for yy in self.y_misc_dask
            ]
        else:
            self.y_dask = None
            self.y = None
            self.y_misc_dask = None
            self.y_misc = None

        if sample_weight is not None:
            self.sample_weight_dask = sample_weight
            self.sample_weight = np.asarray(
                self.sample_weight_dask[:self.chunk_size])
        else:
            self.sample_weight_dask = None
            self.sample_weight = None

        if self.x.ndim != 4:
            raise ValueError(
                'Input data in `NumpyArrayIterator` '
                'should have rank 4. You passed an array '
                'with shape', self.x.shape)
        channels_axis = 3 if data_format == 'channels_last' else 1
        if self.x.shape[channels_axis] not in {1, 3, 4}:
            # str() so the default data_format=None produces a warning
            # instead of a TypeError from string concatenation.
            warnings.warn('NumpyArrayIterator is set to use the '
                          'data format convention "' + str(data_format) + '" '
                          '(channels on axis ' + str(channels_axis) +
                          '), i.e. expected either 1, 3, or 4 '
                          'channels on axis ' + str(channels_axis) + '. '
                          'However, it was passed an array with shape ' +
                          str(self.x.shape) + ' (' +
                          str(self.x.shape[channels_axis]) + ' channels).')

        self.n_aug = aug_per_im
        self.n_images = self.x_dask.shape[0]
        self.target_shape = target_shape
        self.image_data_generator = image_data_generator
        self.data_format = data_format
        self.save_to_dir = save_to_dir
        self.save_prefix = save_prefix
        self.save_format = save_format

        # Indexing in the parent iterator runs over the current chunk only.
        super(DaskArrayIterator, self).__init__(self.chunk_size, batch_size,
                                                shuffle, seed)