def on_loss_calculation_end(self, training_context):
        """Apply CutMix to the current batch and add the mixed loss to the running loss.

        A random box (from ``self.rand_bbox``) of a shuffled copy of the batch is
        pasted into each input. The model is run on the pasted inputs and the loss
        is the area-weighted sum of the losses against the original labels and
        against the shuffled labels. Nothing is returned; the result is written
        back into ``training_context``.

        Args:
            training_context (dict): mutable training state. Reads 'train_data',
                'current_model', 'current_batch', 'is_collect_data' and 'steps';
                updates 'current_loss' and, when collecting, records
                'cutmix_loss' via 'losses'.
        """
        train_data = training_context['train_data']
        x = None
        y = None
        if get_backend() == 'pytorch':
            x = train_data.value_list[0].clone()  # input
            y = train_data.value_list[1].clone()  # label
        elif get_backend() == 'tensorflow':
            x = copy.deepcopy(train_data.value_list[0])  # input
            y = copy.deepcopy(train_data.value_list[1])  # label
        model = training_context['current_model']
        # Beta-distributed mixing ratio; alpha <= 0 disables mixing (lam == 1).
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 1

        batch_size = int_shape(x)[0]
        index = shuffle(cast(arange(batch_size), 'int64'))

        this_loss = None
        if get_backend() == 'pytorch':
            y_a, y_b = y, y[index]
            bbx1, bby1, bbx2, bby2 = self.rand_bbox(x.shape[3], x.shape[2], lam)
            x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.shape[3] * x.shape[2]))
            pred = model(to_tensor(x, requires_grad=True))
            this_loss = lam * self.loss_criterion(pred, y_a.long()) + (1 - lam) * self.loss_criterion(pred, y_b.long())
        elif get_backend() == 'tensorflow':
            y1 = tf.gather(y, index, axis=0)
            x1 = tf.gather(x, index, axis=0)
            y_a, y_b = y, y1
            bbx1, bby1, bbx2, bby2 = self.rand_bbox(x.shape[2], x.shape[1], lam)
            # TF tensors do not support slice assignment, so build a 0/1 mask of
            # the box and blend: where the mask is 1 the shuffled sample is used.
            # float32 so the mask does not clash with (typically float32) images.
            box_mask = np.zeros(int_shape(x), dtype=np.float32)
            box_mask[:, bbx1:bbx2, bby1:bby2, :] = 1
            # Convert the mask itself (this used to be `to_tensor(x)`, which
            # discarded the mask and blended the images with themselves).
            box_mask = to_tensor(box_mask)
            x = x * (1 - box_mask) + x1 * box_mask
            # adjust lambda to exactly match pixel ratio
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.shape[2] * x.shape[1]))
            pred = model(to_tensor(x, requires_grad=True))
            loss1 = self.loss_criterion(pred, y_a)
            loss2 = self.loss_criterion(pred, y_b)
            this_loss = lam * loss1 + (1 - lam) * loss2

        # Dump the mixed samples of the first batch for visual inspection.
        if training_context['current_batch'] == 0:
            for item in x:
                item = unnormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(to_numpy(item))
                item = unnormalize(0, 255)(item)
                array2image(item).save('Results/cutmix_{0}.jpg'.format(get_time_suffix()))

        training_context['current_loss'] = training_context['current_loss'] + this_loss * self.loss_weight
        if training_context['is_collect_data']:
            training_context['losses'].collect('cutmix_loss', training_context['steps'], float(to_numpy(this_loss * self.loss_weight)))
    def on_loss_calculation_end(self, training_context):
        """Blend the batch with a shuffled copy of itself (mixup) and add the mixed loss.

        The blend coefficient is drawn from Beta(alpha, alpha) and clamped to
        [0.3, 0.7]. The model is run on the blended inputs and the loss is the
        coefficient-weighted sum of the losses against the original labels and
        against the shuffled labels. The weighted loss is added to
        ``training_context['current_loss']`` and, when data collection is on,
        recorded as 'mixup_loss'. When ``current_batch`` is 0 every blended
        sample is written out as a JPEG.

        Args:
            training_context (dict): mutable training state. Reads 'train_data',
                'current_model', 'is_collect_data', 'steps' and 'current_batch';
                updates 'current_loss'.
        """
        batch = training_context['train_data']
        inputs = batch.value_list[0].copy().detach()  # input
        labels = batch.value_list[1].copy().detach()  # label
        net = training_context['current_model']

        # Clamp the Beta sample so neither source dominates the blend.
        lam = builtins.min(builtins.max(np.random.beta(self.alpha, self.alpha), 0.3), 0.7)

        order = cast(shuffle(arange(int_shape(inputs)[0])), 'long')

        mixed_x = None
        loss = None
        backend = get_backend()
        if backend == 'pytorch':
            mixed_x = lam * inputs + (1 - lam) * inputs[order, :]
            pred = net(to_tensor(mixed_x, requires_grad=True))
            shuffled_labels = labels[order]
            loss = lam * self.loss_criterion(pred, labels.long()) + (1 - lam) * self.loss_criterion(pred, shuffled_labels.long())
        elif backend == 'tensorflow':
            shuffled_inputs = tf.gather(inputs, order, axis=0)
            shuffled_labels = tf.gather(labels, order, axis=0)
            mixed_x = lam * inputs + (1 - lam) * shuffled_inputs
            pred = net(to_tensor(mixed_x, requires_grad=True))
            loss = lam * self.loss_criterion(pred, labels) + (1 - lam) * self.loss_criterion(pred, shuffled_labels)

        weighted_loss = loss * self.loss_weight
        training_context['current_loss'] = training_context['current_loss'] + weighted_loss
        if training_context['is_collect_data']:
            training_context['losses'].collect('mixup_loss', training_context['steps'], float(to_numpy(weighted_loss)))

        if training_context['current_batch'] != 0:
            return

        for sample in mixed_x:
            write_to_results = self.save_path is None and not is_in_colab()
            if not write_to_results and self.save_path is None:
                # No save_path and is_in_colab() is true: nothing is written.
                continue
            picture = unnormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(to_numpy(sample))
            picture = unnormalize(0, 255)(picture)
            if write_to_results:
                destination = 'Results/mixup_{0}.jpg'.format(get_time_suffix())
            else:
                destination = os.path.join(self.save_path, 'mixup_{0}.jpg'.format(get_time_suffix()))
            array2image(picture).save(destination)
    def plot_tile_image(self, training_context):
        """Render a tile of input / target / prediction images for the current batch.

        Runs the current model on the batch input in eval mode and builds a tile
        with the input images, the target and the prediction. When the model's
        output size at ``model.filter_index`` is greater than 2, target and
        prediction are colourised with ``self.palette``. Otherwise they are used
        as masks that composite the input over ``self.background``. The figure
        is saved under ``self.save_path`` / ``self.tile_image_name_prefix`` and,
        when tensorboard is enabled, also logged to ``ctx.summary_writer``.

        Args:
            training_context (dict): reads 'data_feed', 'train_data',
                'current_model', 'training_name' and 'steps'.
        """
        axis = -1 if get_backend() == 'tensorflow' else 1

        tile_images_list = []
        inputs = None
        target = None
        output = None
        data_feed = training_context['data_feed']
        data = training_context['train_data']
        model = training_context['current_model']
        is_label_mask = model.output_shape[model.filter_index] > 2
        input_key = data_feed[model.signature.inputs.key_list[0]]
        for data_key in data.key_list:
            if data_key == input_key:
                inputs = data[input_key]
                model.eval()
                try:
                    if is_label_mask:
                        output = to_numpy(argmax(model(inputs), axis=axis))
                    else:
                        output = to_numpy(
                            expand_dims(cast(argmax(model(inputs), axis=axis),
                                             inputs.dtype),
                                        axis=-1))
                finally:
                    # Restore training mode even if the forward pass raises.
                    model.train()
            elif (
                    'target' in data_key or 'label' in data_key
                    or 'mask' in data_key
            ) and 'output' not in data_key and data_key in data_feed.value_list:
                target = to_numpy(data[data_key])

        output_arr = None
        if 'alpha' not in data:
            # Keep the raw predicted labels for compositing before colourising.
            output_arr = output.copy()
            if is_label_mask:
                target = label2color(target, self.palette)
                output = label2color(output, self.palette)
        else:
            if get_backend() == 'tensorflow':
                output = output[:, :, :, 1:2] * argmax(output, axis)
            else:
                output = (output[:, 1:2, :, :] *
                          argmax(output, axis)).transpose(0, 2, 3, 1)
            target = to_numpy(data['alpha'])
            # Previously output_arr stayed None on this path and the
            # compositing below failed with a TypeError.
            output_arr = output

        inputs = to_numpy(inputs)
        input_arr = [self.reverse_image_transform(image) for image in inputs]
        tile_images_list.append(input_arr)

        if is_label_mask:
            tile_images_list.append(target)
            tile_images_list.append(output)
        else:
            target_arr = target

            if len(target.shape) < len(int_shape(inputs)):
                # NOTE(review): for pytorch this inserts the channel axis at
                # position 1, while output_arr is channel-last (expand_dims
                # axis=-1 above); confirm the layout reverse_image_transform
                # returns.
                if get_backend() == 'tensorflow':
                    target_arr = np.expand_dims(target, -1)
                else:
                    target_arr = np.expand_dims(target, 1)

            if 'alpha' not in data:
                # Copy before binarising in place: to_numpy() may share memory
                # with the batch's label tensor, and np.expand_dims returns a view.
                target_arr = target_arr.copy()
                target_arr[target_arr > 0] = 1

            background = np.ones_like(target_arr) * self.background

            tile_images_list.append(target_arr * input_arr +
                                    (1 - target_arr) * background)

            tile_images_list.append(output_arr * input_arr +
                                    (1 - output_arr) * background)

        fig = tile_rgb_images(*tile_images_list,
                              save_path=os.path.join(
                                  self.save_path, self.tile_image_name_prefix),
                              imshow=True)
        if ctx.enable_tensorboard and ctx.summary_writer is not None:
            ctx.summary_writer.add_figure(
                training_context['training_name'] + '/plot/segtile_image',
                fig,
                global_step=training_context['steps'],
                close=True,
                walltime=time.time())
        plt.close()