Beispiel #1
0
    def render(self, variables):
        """Add per-channel and pairwise-difference visualisations, then render.

        Returns ``True`` without rendering once ``variables['__iteration__']``
        reaches ``self.num_block``. Otherwise, for every name in
        ``self.n_ch_images`` it writes one colourised 3-channel image per input
        channel under ``'NchImageVisualizer.<name>.__<i>_ch_images'``. For every
        pair in ``self.subtract_images`` it writes a red/blue signed-difference
        image per channel under
        ``'NchImageVisualizer.<a>-<b>.__<i>_ch_subtract_images'`` plus a text
        legend under ``'NchImageVisualizer.legend.<a>-<b>'``. Finally it returns
        ``super().render(variables)``.

        ``variables`` is mutated in place.
        """
        iteration = variables['__iteration__']
        if iteration >= self.num_block:
            return True

        processed_image = {}

        for n_ch_image in self.n_ch_images:
            images = variables[n_ch_image]
            if isinstance(images, list):
                # Stack a list of per-sample arrays into one batch along axis 0.
                images = chainer.functions.concat(
                    [F.expand_dims(img, axis=0) for img in images], axis=0)

            processed_image[n_ch_image] = images
            for i, color in zip(range(self.n_ch), cycle(self.color_pallete)):
                index = 'NchImageVisualizer.{}.__{}_ch_images'.format(
                    n_ch_image, i)
                channel = images[:, i]
                # Tint channel i with the (r, g, b) weights of its palette colour.
                variables[index] = chainer.functions.concat((
                    F.expand_dims(channel * color[0], axis=1),
                    F.expand_dims(channel * color[1], axis=1),
                    F.expand_dims(channel * color[2], axis=1),
                ), axis=1)

        for subtract_pair in self.subtract_images:
            assert subtract_pair[0] in processed_image
            assert subtract_pair[1] in processed_image
            pair_string = '-'.join(subtract_pair)
            # make subtract images
            lhs_img = utils.unwrapped(processed_image[subtract_pair[0]])
            rhs_img = utils.unwrapped(processed_image[subtract_pair[1]])
            # NOTE(review): rhs is scaled by 2 before subtracting -- confirm
            # this is intended rather than a plain lhs - rhs difference.
            subtract_img = lhs_img - rhs_img * 2
            abs_img = np.absolute(subtract_img)
            for i, _ in zip(range(self.n_ch), cycle(self.color_pallete)):
                index = 'NchImageVisualizer.{}.__{}_ch_subtract_images'.format(
                    pair_string, i)
                diff = subtract_img[:, i]
                magnitude = abs_img[:, i]
                # Positive differences go to channel 0, negative to channel 2.
                rgb = np.stack((
                    magnitude * (diff > 0) * 255,
                    magnitude * 0.0,
                    magnitude * (diff < 0) * 255,
                ), axis=1)
                # Clip before the uint8 cast so large magnitudes saturate
                # instead of silently wrapping around.
                variables[index] = np.clip(rgb, 0, 255).astype(np.uint8)

            # make legend image (H, W, 3) sized like the spatial dims above
            legend_image = np.zeros(subtract_img.shape[2:] + (3, ))
            cv2.putText(legend_image, subtract_pair[0], (0, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0))
            cv2.putText(legend_image, subtract_pair[1], (0, 120),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255))
            # HWC -> CHW keeping H before W, matching the channel-first
            # layout (channels on axis 1) used for the images above.
            legend_chw = np.transpose(legend_image, (2, 0, 1))
            variables['NchImageVisualizer.legend.' + pair_string] = np.repeat(
                np.expand_dims(legend_chw, axis=0),
                subtract_img.shape[0],
                axis=0)
        return super().render(variables)
Beispiel #2
0
    def train(self):
        """Run the training loop over ``self.train_iter`` with a progress bar.

        Each iteration:

        - runs inference on every stage input of the batch;
        - updates the network;
        - collects and clears ``self.network.variables`` and unwraps them;
        - mirrors the unwrapped values under a ``'train.'`` prefix;
        - every ``self.n_valid_step`` iterations, merges validation results
          under a ``'valid.'`` prefix;
        - calls every logger and advances the progress bar.

        The network architecture is written once, on the first iteration.
        """
        with tqdm.tqdm(total=self.n_max_train_iter) as pbar:
            for i, batch in enumerate(self.train_iter):
                self.train_iteration = i
                variables = {
                    '__iteration__': i,
                    '__train_iteration__': self.train_iteration,
                }

                input_vars = self.batch_to_vars(batch)

                # Inference current batch.
                for stage_input in input_vars:
                    self.inference(stage_input, is_train=True)
                sleep(1e-3)

                # Back propagation and update network
                self.network.update()

                # Move the network's collected variables into this step's dict.
                variables.update(self.network.variables)
                self.network.variables.clear()

                # Save network architecture
                if self.train_iteration == 0:
                    self.write_network_architecture(
                        self.architecture_loss[0],
                        variables[self.architecture_loss[1]]
                    )

                # Unwrap chainer variables in place (keys unchanged, so the
                # dict size does not change during iteration), then mirror
                # every entry under a 'train.' prefix.
                for var_name, value in variables.items():
                    variables[var_name] = utils.unwrapped(value)
                variables.update({'train.' + name: utils.unwrapped(value)
                                  for name, value in variables.items()})

                # Validate when the current iteration is a multiple of n_valid_step.
                if i % self.n_valid_step == 0:
                    valid_variables = self.validate(variables=variables)
                    variables.update(
                        {'valid.' + name: value for name, value in valid_variables.items()})
                    self.network.variables.clear()
                    del valid_variables

                # Write log
                for logger in self.logger:
                    logger(variables, is_valid=False)

                # Update progress bar
                self.print_description(pbar, variables)
                pbar.update()
                # NOTE(review): this check runs after iteration i is fully
                # processed, so iterations 0..n_max_train_iter (n_max + 1 in
                # total) run while the bar's total is n_max -- confirm intended.
                if self.n_max_train_iter <= i:
                    break

                # Refresh variables
                variables.clear()
                gc.collect()
Beispiel #3
0
def blend_image(*image_pairs):
    """Blend (foreground, background) image pairs 50/50 after normalising each.

    Arguments are consumed as consecutive pairs ``fore0, back0, fore1, back1, ...``.
    A trailing unpaired argument is ignored. Each image is unwrapped to a plain
    array first. When the channel counts (axis 1) differ, the image with fewer
    channels is repeated along axis 1 to match.

    Returns:
        A list with one blended array per pair.

    Raises:
        ValueError: if the larger channel count is not a multiple of the
            smaller one, so the images cannot be matched by repetition.
    """
    result_image = []
    for fore, back in zip(image_pairs[::2], image_pairs[1::2]):
        fore = utils.unwrapped(fore)
        back = utils.unwrapped(back)
        fore_ch, back_ch = fore.shape[1], back.shape[1]
        if max(fore_ch, back_ch) % min(fore_ch, back_ch) != 0:
            raise ValueError(
                'cannot match channel counts {} and {} by repetition'.format(
                    fore_ch, back_ch))
        # Both are plain arrays after unwrapping, so use numpy's repeat rather
        # than building chainer graph nodes.
        if back_ch > fore_ch:
            fore = np.repeat(fore, back_ch // fore_ch, axis=1)
        elif back_ch < fore_ch:
            back = np.repeat(back, fore_ch // back_ch, axis=1)
        result_image.append(normalize(fore) * 0.5 + normalize(back) * 0.5)
    return result_image
Beispiel #4
0
def normalize(*images):
    """Min-max scale each image into the range [0, 1].

    Each argument is unwrapped to a plain array, shifted by its global minimum
    and divided by its value range. A small epsilon keeps constant images from
    dividing by zero.

    Returns:
        The single normalised array when called with one image, otherwise a
        list of normalised arrays in argument order.
    """
    results = []
    for img in images:
        img = utils.unwrapped(img)
        amax = np.amax(img)
        amin = np.amin(img)
        # Shift by the minimum (not the maximum) so the output spans [0, 1].
        results.append((img - amin) / (amax - amin + 1e-8))
    return results[0] if len(results) == 1 else results
Beispiel #5
0
    def peek(self, variable):
        """Paste every patch of a batch into its per-case accumulation image.

        Reads the patch batch, case names and crop regions (and, if configured,
        source volumes) from ``variable``. For a case seen for the first time,
        a zero ``float32`` image is allocated with ``self.image_shape`` if set,
        otherwise with the shape of that sample's source volume. Each patch is
        then written into the slice described by its crop region, where every
        entry of the region is unpacked into ``slice(*entry)``.
        """
        patch_image = variable[self.patch]
        case_names = variable[self.case_name]
        crop_regions = variable[self.crop_region]
        if self.source_volume is not None:
            source_volumes = variable[self.source_volume]
        else:
            # One placeholder per patch so zip() below does not truncate.
            source_volumes = [None] * len(patch_image)

        for patch, case_name, crop_region, source_volume in zip(
                patch_image, case_names, crop_regions, source_volumes):
            if case_name not in self.images:
                if self.image_shape is not None:
                    image_shape = self.image_shape
                else:
                    image_shape = source_volume.shape
                self.images[case_name] = np.zeros(image_shape, np.float32)
            region = tuple(slice(*c) for c in crop_region)
            self.images[case_name][region] = utils.unwrapped(patch)
Beispiel #6
0
    def __call__(self, variables, is_valid=False):
        """Collect the configured variables, apply optional weights, and dump them.

        ``self.dump_variables`` and ``self.weights`` are paired with
        ``zip_longest``. Each name is handled as follows:

        - A name missing from ``variables`` is dumped as ``''``, together with
          empty ``raw.<name>`` / ``weight.<name>`` entries.
        - An empty-string weight means "no weight".
        - A string weight that is a key of ``variables`` is split on ``'.'``.
          Its head selects a variable and the remaining parts are resolved
          with ``get_field``.
        - Any other string weight is resolved the same way against the process
          registered under its head. Failures there fall back to "no weight".

        Weighted variables are dumped multiplied by ``float(weight)``. Their raw
        value and weight are appended after all dumped variables.
        """
        dump_vars = OrderedDict()
        appendix_vars = OrderedDict()

        for var_name, weight in zip_longest(self.dump_variables, self.weights):
            if var_name not in variables:
                dump_vars[var_name] = ''
                appendix_vars['raw.{}'.format(var_name)] = ''
                appendix_vars['weight.{}'.format(var_name)] = ''
                continue

            if isinstance(weight, str):
                if weight == '':
                    weight = None
                elif weight in variables:
                    # Use dedicated names: reusing `var_name` here used to
                    # overwrite the variable being dumped with the weight's.
                    weight_var, *weight_fields = weight.split('.')
                    weight = get_field(variables[weight_var], weight_fields)
                else:
                    process_name, *process_fields = weight.split('.')
                    try:
                        process = deepnet.core.network.build.get_process(
                            process_name)
                        weight = utils.unwrapped(
                            get_field(process, process_fields))
                    except Exception:
                        # Best effort: an unresolvable weight means "unweighted".
                        weight = None

            if weight is None:
                dump_vars[var_name] = variables[var_name]
            else:
                scale = float(weight)
                dump_vars[var_name] = variables[var_name] * scale
                appendix_vars['raw.{}'.format(var_name)] = variables[var_name]
                appendix_vars['weight.{}'.format(var_name)] = scale

        dump_vars.update(appendix_vars)
        self.dump(dump_vars)