Code example #1
    def on_epoch_init(self, lr, train, epoch, total_epochs):
        """Record epoch bookkeeping and decide whether this epoch is tracked.

        Args:
            lr: sequence of learning rates; lr[0] is logged during training.
            train: True for a training epoch, False for validation.
            epoch: current epoch index.
            total_epochs: total number of epochs in the run.
        """
        self.epoch = epoch
        self.total_epochs = total_epochs

        if train:
            summary.scalar('train/lr', lr[0], global_step=epoch)

        # Track the first two epochs, and every n-th epoch after that.
        first_epochs = epoch <= 2
        periodic = (epoch % self.images_every_n_epochs) == 0
        self.track_epoch = first_epochs or periodic

        if not self.track_epoch:
            return

        # Pre-select which steps of this epoch will be visualized.
        if train and self.keep_num_training_steps > 0:
            self.random_step_indices = random_indices(
                k=self.keep_num_training_steps,
                n=len(self.training_loader))
        elif not train and self.keep_num_validation_steps > 0:
            self.random_step_indices = random_indices(
                k=self.keep_num_validation_steps,
                n=len(self.validation_loader))
Code example #2
    def on_step_finished(self, example_dict, model_dict, loss_dict, train, step, total_steps):
        """Log per-step training scalars and, on tracked epochs, image panels.

        Args:
            example_dict: batch data; expects 'basename', 'input1'..'input5'
                and 'target1' entries (image tensors, NCHW).
            model_dict: model outputs; expects 'output1' and optionally
                'warped1', 'warped2', 'warped4', 'warped5'.
            loss_dict: scalar losses, logged during training only.
            train: True for a training step, False for validation.
            step: step index within the current epoch.
            total_steps: total steps in the epoch (unused here).
        """
        if train:
            self.global_train_step += 1
            for key, value in loss_dict.items():
                summary.scalar('train/%s' % key, value, global_step=self.global_train_step)

        def from_basename(name, in_basename):
            return "{}_{}/{}".format(in_basename, self.prefix, name)

        # ------------------------------------------------------------------------------
        # We also track some epochs completely on a subset of simple
        # ------------------------------------------------------------------------------
        if self.track_epoch:
            # Training steps are keyed by the global step; validation by epoch.
            global_step = self.global_train_step if train else self.epoch
            if step in self.random_step_indices:
                # Only the first example of the batch is visualized.
                batch_idx = 0
                basename = example_dict['basename'][batch_idx]
                basename = basename.replace('/', '_')
                make_name = functools.partial(from_basename, in_basename=basename)

                # visualize inputs
                input1 = example_dict['input1'][batch_idx:batch_idx + 1, ...]
                input2 = example_dict['input2'][batch_idx:batch_idx + 1, ...]
                input3 = example_dict['input3'][batch_idx:batch_idx + 1, ...]
                input4 = example_dict['input4'][batch_idx:batch_idx + 1, ...]
                input5 = example_dict['input5'][batch_idx:batch_idx + 1, ...]
                summary.images(make_name('input1'), input1, global_step=global_step)
                summary.images(make_name('input2'), input2, global_step=global_step)
                summary.images(make_name('input3'), input3, global_step=global_step)
                summary.images(make_name('input4'), input4, global_step=global_step)
                summary.images(make_name('input5'), input5, global_step=global_step)

                # visualize target
                target1 = example_dict['target1'][batch_idx:batch_idx + 1, ...]
                summary.images(make_name('gt'), target1, global_step=global_step)

                # visualize output
                output1 = model_dict['output1'][batch_idx:batch_idx + 1, ...]
                summary.images(make_name('output'), output1, global_step=global_step)

                # visualize error: per-pixel squared error, colorized via err2rgb.
                # (BUGFIX: removed dead `b, _, h, w = output1.size()` unpack —
                # none of those locals were used.)
                error1 = torch.sum((output1 - target1) ** 2, dim=1, keepdim=True)
                error1 = self.err2rgb(error1)
                summary.images(make_name('error'), error1, global_step=global_step)

                # visualize concatenated summary image:
                # [input3 | target] over [error | output], then downsampled.
                x1 = torch.cat((input3, target1), dim=3)
                x2 = torch.cat((error1.float() / 255, output1), dim=3)
                x = torch.cat((x1, x2), dim=2)
                summary.images(make_name('summary'), self.downsample(x), global_step=global_step)

                # visualize warped images, if the model produced them
                if 'warped1' in model_dict.keys():
                    warped1 = model_dict['warped1'][batch_idx:batch_idx + 1, ...]
                    warped2 = model_dict['warped2'][batch_idx:batch_idx + 1, ...]
                    warped4 = model_dict['warped4'][batch_idx:batch_idx + 1, ...]
                    warped5 = model_dict['warped5'][batch_idx:batch_idx + 1, ...]
                    summary.images(make_name('warped1'), warped1, global_step=global_step)
                    summary.images(make_name('warped2'), warped2, global_step=global_step)
                    summary.images(make_name('warped4'), warped4, global_step=global_step)
                    summary.images(make_name('warped5'), warped5, global_step=global_step)
Code example #3
 def on_epoch_finished(self, avg_loss_dict, train, epoch, total_epochs):
     """Log averaged validation losses at the end of a validation epoch.

     Training losses are logged per step elsewhere, so nothing is
     written here for training epochs.
     """
     if train:
         return
     for name, avg in avg_loss_dict.items():
         summary.scalar('valid/%s' % name, avg, global_step=epoch)
Code example #4
    def on_step_finished(self, example_dict, model_dict, loss_dict, train,
                         step, total_steps):
        """Log training scalars and, on tracked steps, a flow progress strip.

        The progress image concatenates [input1 | input2 | predicted flow |
        ground-truth flow] side by side, rescaled to a height of 128 px.

        Args:
            example_dict: batch data; expects 'basename', 'input1', 'input2'
                and 'target1' (ground-truth flow) entries.
            model_dict: model outputs; expects 'flow2' (train) or
                'flow1' (validation).
            loss_dict: scalar losses, logged during training only.
            train: True for a training step, False for validation.
            step: step index within the current epoch.
            total_steps: total steps in the epoch (unused here).
        """

        prefix = 'train' if train else 'valid'

        def from_basename(name, in_basename):
            return "{}/{}/{}".format(prefix, in_basename, name)

        if train:
            self.global_train_step += 1
            for key, value in loss_dict.items():
                summary.scalar('train/%s' % key,
                               value,
                               global_step=self.global_train_step)

        if self.track_epoch:
            # Training steps are keyed by the global step; validation by epoch.
            global_step = self.global_train_step if train else self.epoch
            if step in self.random_step_indices:
                basename = example_dict['basename'][0]
                basename = basename.replace('/', '_')
                make_name = functools.partial(from_basename,
                                              in_basename=basename)

                input1 = example_dict['input1']
                input2 = example_dict['input2']
                target1 = example_dict['target1']

                progress = [input1, input2]

                # BUGFIX: use out-of-place arithmetic below so that logging
                # does not mutate the tensors stored in example_dict /
                # model_dict (the original used `*=` and `/=`, which modify
                # the caller's data in place).
                if train:
                    flow = model_dict['flow2']
                    target1 = target1 * self.args.loss_div_flow
                    target1 = downsample2d_as(target1, flow)
                else:
                    flow = model_dict['flow1'] / self.args.loss_div_flow

                # Shared color scale: normalize both flows by the maximum
                # ground-truth flow magnitude.
                u, v = target1.chunk(chunks=2, dim=1)
                rad = torch.sqrt(u**2 + v**2)
                max_flow = max2d(rad, keepdim=True)
                max_flow = max_flow.repeat(2, 1, 1, 1)

                # Colorize prediction and target in one flow2rgb call,
                # then split the result again.
                flowim, targetim = self.flow2rgb(torch.cat((flow, target1),
                                                           dim=0),
                                                 max_flow=max_flow).chunk(
                                                     chunks=2, dim=0)
                size = tuple(flow.size()[2:4])

                # Resize the raw input images to the flow resolution.
                progress = [
                    tf.interpolate(im,
                                   size,
                                   mode='bilinear',
                                   align_corners=True) for im in progress
                ]

                progress.append(flowim)
                progress.append(targetim)

                progress = torch.cat(progress, dim=-1)

                # Rescale the strip to a fixed height of 128 pixels,
                # preserving the aspect ratio.
                if progress.size(2) != 128:
                    factor = 128 / progress.size(2)
                    new_height = int(progress.size(2) * factor)
                    new_width = int(progress.size(3) * factor)
                    new_size = (new_height, new_width)
                    progress = tf.interpolate(progress,
                                              size=new_size,
                                              mode='nearest')

                progress = progress[0, ...]

                summary.image(make_name('progress'),
                              progress,
                              global_step=global_step)