Example #1
    def inference(args, epoch, data_loader, logger, model, offset=0):

        model.eval()

        if args.save_flow or args.render_validation:
            flow_folder = "{}/{}.epoch-{}-flow-field".format(
                args.inference_dir, args.name.replace('/', '.'), epoch)
            rendered_flow_folder = "{}/{}.epoch-{}-rendered-flow-field".format(
                args.inference_dir, args.name.replace('/', '.'), epoch)
            if not os.path.exists(flow_folder):
                os.makedirs(flow_folder)
            if not os.path.exists(rendered_flow_folder):
                os.makedirs(rendered_flow_folder)

        if args.inference_n_batches < 0:
            args.inference_n_batches = np.inf

        progress = tqdm(data_loader,
                        ncols=100,
                        total=np.minimum(len(data_loader),
                                         args.inference_n_batches),
                        desc='Inferencing ',
                        leave=True,
                        position=offset)

        statistics = []
        total_loss = 0
        for batch_idx, (data, target) in enumerate(progress):
            if args.cuda:
                # 'async' became a reserved word in Python 3.7;
                # non_blocking is the current name of this argument
                data = [d.cuda(non_blocking=True) for d in data]
                target = [t.cuda(non_blocking=True) for t in target]
            # volatile Variables disable autograd bookkeeping at inference
            # (legacy pre-0.4 PyTorch API; torch.no_grad() is the modern equivalent)
            data = [Variable(d, volatile=True) for d in data]
            target = [Variable(t, volatile=True) for t in target]

            # when ground-truth flows are not available for inference_dataset,
            # the targets are set to all zeros; the losses are then simply the
            # L1 or L2 norms of the computed optical flows, depending on the
            # type of loss norm passed in
            losses, output = model(data[0], target[0], inference=True)

            losses = [torch.mean(loss_value) for loss_value in losses]
            loss_val = losses[0]  # first loss value; used only for reporting at inference
            total_loss += loss_val.data[0]  # .data[0] is the pre-0.4 spelling of .item()
            loss_values = [v.data[0] for v in losses]

            # gather loss_labels here; returning them directly from the model
            # leads to a recursion-limit error as it looks for Variables to gather
            loss_labels = list(model.module.loss.loss_labels)

            statistics.append(loss_values)
            if args.save_flow or args.render_validation:
                for i in range(args.inference_batch_size):
                    _pflow = output[i].data.cpu().numpy().transpose(1, 2, 0)
                    ground_truth = target[0][i].data.cpu().numpy().transpose(
                        1, 2, 0)
                    render_img = tools.flow_to_image(_pflow).transpose(2, 0, 1)
                    true_img = tools.flow_to_image(ground_truth).transpose(
                        2, 0, 1)
                    render_img = torch.Tensor(render_img) / 255.0
                    true_img = torch.Tensor(true_img) / 255.0
                    input_img = data[0][i, :, 0, :, :].data.cpu() / 255.0
                    logger.add_image('renderimg',
                                     torchvision.utils.make_grid(render_img),
                                     batch_idx * args.inference_batch_size + i)
                    logger.add_image('ground_truth',
                                     torchvision.utils.make_grid(true_img),
                                     batch_idx * args.inference_batch_size + i)
                    logger.add_image('input_img',
                                     torchvision.utils.make_grid(input_img),
                                     batch_idx * args.inference_batch_size + i)
                    if args.save_flow:
                        # scipy.misc.imsave was removed in SciPy 1.2;
                        # imageio.imwrite is a drop-in replacement on newer stacks
                        scipy.misc.imsave(
                            join(
                                rendered_flow_folder, '%06d.png' %
                                (batch_idx * args.inference_batch_size + i)),
                            render_img.numpy().transpose(1, 2, 0))
                        flow_utils.writeFlow(
                            join(
                                flow_folder, '%06d.flo' %
                                (batch_idx * args.inference_batch_size + i)),
                            _pflow)

            progress.set_description(
                'Inference Averages for Epoch {}: '.format(epoch) +
                tools.format_dictionary_of_losses(
                    loss_labels,
                    np.array(statistics).mean(axis=0)))
            progress.update(1)

            if batch_idx == (args.inference_n_batches - 1):
                break

        progress.close()

        return
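
A minimal sketch of how this inference routine might be driven. The Namespace fields mirror the attributes the function reads; the data loader, logger, and model wiring are illustrative placeholders, not the repository's actual entry point:

    # hypothetical driver for inference(); adapt to the real training script
    from argparse import Namespace
    import torch

    args = Namespace(save_flow=True, render_validation=False,
                     inference_dir='./work', name='flownet2',
                     inference_n_batches=-1,  # negative means "use the whole loader"
                     inference_batch_size=1,
                     cuda=torch.cuda.is_available())
    # val_loader, logger and model are placeholders; model is assumed to be
    # wrapped in nn.DataParallel so that model.module.loss.loss_labels exists
    inference(args, epoch=0, data_loader=val_loader, logger=logger,
              model=model, offset=0)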
Example #2
    def forward(self, f, b, mask=None):
        """ Contextual attention layer implementation.
        Contextual attention is first introduced in publication:
            Generative Image Inpainting with Contextual Attention, Yu et al.
        Args:
            f: Input feature to match (foreground).
            b: Input feature for match (background).
            mask: Input mask for b, indicating patches not available.
            ksize: Kernel size for contextual attention.
            stride: Stride for extracting patches from b.
            rate: Dilation for matching.
            softmax_scale: Scaled softmax for attention.
        Returns:
            torch.tensor: output
        """
        # get shapes
        raw_int_fs = list(f.size())  # b*c*h*w
        raw_int_bs = list(b.size())  # b*c*h*w

        # extract patches from background with stride and rate
        kernel = 2 * self.rate
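        # raw patches are twice as wide as the extraction stride, so every
        # output pixel is covered by up to 4 overlapping patches; this is why
        # the transposed-conv reconstruction further below divides by 4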
        # raw_w is extracted for reconstruction
        raw_w = extract_image_patches(b,
                                      ksizes=[kernel, kernel],
                                      strides=[self.rate,
                                               self.rate])  # b*hw*c*k*k
        raw_w_groups = torch.split(raw_w, 1, dim=0)

        # downscale both foreground and background for matching;
        # the original background patches (raw_w) are kept for reconstruction
        f = F.interpolate(f, scale_factor=1 / self.rate, mode='nearest')
        b = F.interpolate(b, scale_factor=1 / self.rate, mode='nearest')
        int_fs = list(f.size())  # b*c*h*w
        int_bs = list(b.size())
        f_groups = torch.split(
            f, 1, dim=0)  # split tensors along the batch dimension

        w = extract_image_patches(b,
                                  ksizes=[self.ksize, self.ksize],
                                  strides=[self.stride,
                                           self.stride])  # b*hw*c*k*k
        w_groups = torch.split(w, 1, dim=0)

        # process mask
        if mask is None:
            mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
            if self.use_cuda:
                mask = mask.cuda()
        else:
            mask = F.interpolate(mask,
                                 scale_factor=1. / (4. * self.rate),
                                 mode='nearest')
        m_groups = extract_image_patches(mask,
                                         ksizes=[self.ksize, self.ksize],
                                         strides=[self.stride,
                                                  self.stride])  # b*hw*c*k*k

        y = []
        offsets = []
        k = self.fuse_k
        scale = self.softmax_scale * 255  # to fit the PyTorch tensor image value range
        fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
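        # identity kernel: the fuse convolution sums matching scores along
        # diagonals of the (background patch, foreground pixel) score map,
        # encouraging neighboring pixels to attend to neighboring patches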
        if self.use_cuda:
            fuse_weight = fuse_weight.cuda()

        for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups,
                                      m_groups):
            '''
            O => output channel as a conv filter
            I => input channel as a conv filter
            xi : separated tensor along batch dimension of front; (B=1, C=128, H=32, W=32)
            wi : separated patch tensor along batch dimension of back; (B=1, O=32*32, I=128, KH=3, KW=3)
            raw_wi : separated tensor along batch dimension of back; (B=1, I=32*32, O=128, KH=4, KW=4)
            '''
            # conv for compare
            escape_NaN = torch.FloatTensor([1e-4])
            if self.use_cuda:
                escape_NaN = escape_NaN.cuda()
            wi = wi[0]  # hw*c*k*k
            wi_normed = wi / torch.max(
                torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3])),
                escape_NaN)
            xi_normed = same_padding(xi, [self.ksize, self.ksize],
                                     [1, 1])  # xi: 1*c*H*W
            yi = F.conv2d(xi_normed, wi_normed, stride=1)  # 1*hw*H*W
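            # one output channel per background patch: each channel holds the
            # correlation of that L2-normalized patch with the foreground features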

            # conv implementation for fuse scores to encourage large patches
            if self.fuse:
                # make all of depth to spatial resolution
                yi = yi.view(1, 1, int_bs[2] * int_bs[3], int_fs[2] *
                             int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
                yi = same_padding(yi, [k, k], [1, 1])
                yi = F.conv2d(yi, fuse_weight,
                              stride=1)  # (B=1, C=1, H=32*32, W=32*32)

                yi = yi.contiguous().view(1, int_bs[2], int_bs[3], int_fs[2],
                                          int_fs[3])  # (B=1, 32, 32, 32, 32)
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3],
                                          int_fs[2] * int_fs[3])
                yi = same_padding(yi, [k, k], [1, 1])
                yi = F.conv2d(yi, fuse_weight, stride=1)
                yi = yi.contiguous().view(1, int_bs[3], int_bs[2], int_fs[3],
                                          int_fs[2])
                yi = yi.permute(0, 2, 1, 4, 3)
                yi = yi.contiguous().view(
                    1, int_bs[2] * int_bs[3], int_fs[2],
                    int_fs[3])  # (B=1, C=32*32, H=32, W=32)

            # mi: hw*c*k*k
            mi = reduce_mean(mi, axis=[1, 2, 3])  # hw*1*1*1
            mi = mi.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
            mm = (mi == 0).to(torch.float32)  # 1*hw*1*1
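            # mm is 1 for background patches containing no masked pixels, so
            # only fully-valid patches can receive attention below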

            # softmax to match
            yi = yi * mm
            yi = F.softmax(yi * scale, dim=1)
            yi = yi * mm  # 1*hw*H*W

            offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W
            if int_bs != int_fs:
                # Normalize the offset value to match foreground dimension
                times = float(int_fs[2] * int_fs[3]) / float(
                    int_bs[2] * int_bs[3])
                offset = ((offset + 1).float() * times - 1).to(torch.int64)
            offset = torch.cat([offset // int_fs[3], offset % int_fs[3]],
                               dim=1)  # 1*2*H*W
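            # offset now holds the (row, col) coordinate of the best-matching
            # patch; subtracting each pixel's own coordinate below turns this
            # into a displacement field for visualization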

            # deconv for patch pasting
            wi_center = raw_wi[0]
            yi = F.conv_transpose2d(yi, wi_center, stride=self.rate,
                                    padding=1) / 4.  # (B=1, C=128, H=64, W=64)
            y.append(yi)
            offsets.append(offset)

        y = torch.cat(y, dim=0)  # back to the mini-batch
        # note that view() is not in-place: the reshaped tensor must be reassigned
        y = y.contiguous().view(raw_int_fs)

        offsets = torch.cat(offsets, dim=0)
        offsets = offsets.view(int_fs[0], 2, *int_fs[2:])

        # case 1: visualize as optical flow by subtracting each pixel's own coordinate
        h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(
            int_fs[0], -1, -1, int_fs[3])
        w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(
            int_fs[0], -1, int_fs[2], -1)
        ref_coordinate = torch.cat([h_add, w_add], dim=1)  # b*2*H*W
        if self.use_cuda:
            ref_coordinate = ref_coordinate.cuda()

        offsets = offsets - ref_coordinate
        # flow = pt_flow_to_image(offsets)

        flow = torch.from_numpy(
            flow_to_image(offsets.permute(0, 2, 3,
                                          1).cpu().data.numpy())) / 255.
        flow = flow.permute(0, 3, 1, 2)
        if self.use_cuda:
            flow = flow.cuda()
        # case2: visualize which pixels are attended
        # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))

        if self.rate != 1:
            flow = F.interpolate(flow,
                                 scale_factor=self.rate * 4,
                                 mode='nearest')

        return y, flow
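
A minimal sketch of exercising this forward pass. The enclosing module is assumed here to be called ContextualAttention with these constructor arguments; the real class name and signature may differ, but the attributes match the ones the method reads:

    # hypothetical usage with random features; shapes follow the b*c*h*w
    # convention annotated in the method body
    attn = ContextualAttention(ksize=3, stride=1, rate=2, fuse_k=3,
                               softmax_scale=10, fuse=True, use_cuda=False)
    f = torch.randn(1, 128, 64, 64)     # foreground features to fill in
    b = torch.randn(1, 128, 64, 64)     # background features to borrow from
    mask = torch.zeros(1, 1, 256, 256)  # 1 where b is invalid (the hole)
    y, flow = attn(f, b, mask)          # y: reconstructed features, flow: for viewing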
Example #3
    def train(args,
              epoch,
              start_iteration,
              data_loader,
              model,
              optimizer,
              logger,
              is_validate=False,
              offset=0):
        statistics = []
        total_loss = 0

        if is_validate:
            model.eval()
            title = 'Validating Epoch {}'.format(epoch)
            if args.validation_n_batches < 0:
                args.validation_n_batches = len(data_loader) - 1
            progress = tqdm(tools.IteratorTimer(data_loader),
                            ncols=100,
                            total=np.minimum(len(data_loader),
                                             args.validation_n_batches),
                            leave=True,
                            position=offset,
                            desc=title)
        else:
            model.train()
            title = 'Training Epoch {}'.format(epoch)
            if args.train_n_batches < 0:
                args.train_n_batches = len(data_loader) - 1
            progress = tqdm(tools.IteratorTimer(data_loader),
                            ncols=120,
                            total=np.minimum(len(data_loader),
                                             args.train_n_batches),
                            smoothing=.9,
                            miniters=1,
                            leave=True,
                            position=offset,
                            desc=title)

        last_log_time = progress._time()
        for batch_idx, (data, target) in enumerate(progress):

            # volatile Variables disable autograd during validation
            # (legacy pre-0.4 PyTorch API; torch.no_grad() is the modern equivalent)
            data = [Variable(d, volatile=is_validate) for d in data]
            target = [Variable(t, volatile=is_validate) for t in target]
            if args.cuda and args.number_gpus == 1:
                # 'async' became a reserved word in Python 3.7;
                # non_blocking is the current name of this argument
                data = [d.cuda(non_blocking=True) for d in data]
                target = [t.cuda(non_blocking=True) for t in target]

            if not is_validate:
                optimizer.zero_grad()
            losses = model(data[0], target[0])
            losses = [torch.mean(loss_value) for loss_value in losses]
            loss_val = losses[1]  # loss value used for the weight update
            total_loss += loss_val.data[0]
            loss_values = [v.data[0] for v in losses]

            # gather loss_labels here; returning them directly from the model
            # leads to a recursion-limit error as it looks for Variables to gather
            loss_labels = list(model.module.loss.loss_labels)

            assert not np.isnan(total_loss)

            if not is_validate and args.fp16:
                loss_val.backward()
                if args.gradient_clip:
                    # clip_grad_norm was renamed clip_grad_norm_ in PyTorch 1.0
                    torch.nn.utils.clip_grad_norm(model.parameters(),
                                                  args.gradient_clip)

                # param_copy is the FP32 master copy of the parameters; it is
                # expected to be defined by the enclosing training script
                params = list(model.parameters())
                for i in range(len(params)):
                    param_copy[i].grad = params[i].grad.clone().type_as(
                        params[i]).detach()
                    param_copy[i].grad.mul_(1. / args.loss_scale)
                optimizer.step()
                for i in range(len(params)):
                    params[i].data.copy_(param_copy[i].data)

            elif not is_validate:
                loss_val.backward()
                if args.gradient_clip:
                    torch.nn.utils.clip_grad_norm(model.parameters(),
                                                  args.gradient_clip)
                optimizer.step()

            # Update hyperparameters if needed
            global_iteration = start_iteration + batch_idx
            if not is_validate:
                tools.update_hyperparameter_schedule(args, epoch,
                                                     global_iteration,
                                                     optimizer)
                loss_labels.append('lr')
                loss_values.append(optimizer.param_groups[0]['lr'])

            loss_labels.append('load')
            loss_values.append(progress.iterable.last_duration)

            # Print out statistics
            statistics.append(loss_values)
            title = '{} Epoch {}'.format(
                'Validating' if is_validate else 'Training', epoch)

            labels = (loss_labels[0]
                      if isinstance(loss_labels[0], (list, tuple))
                      else loss_labels)
            progress.set_description(
                title + ' ' +
                tools.format_dictionary_of_losses(labels, statistics[-1]))

            if ((((global_iteration + 1) % args.log_frequency) == 0
                 and not is_validate) or
                (is_validate and batch_idx == args.validation_n_batches - 1)):

                global_iteration = global_iteration if not is_validate else start_iteration

                logger.add_scalar(
                    'batch logs per second',
                    len(statistics) / (progress._time() - last_log_time),
                    global_iteration)
                last_log_time = progress._time()

                all_losses = np.array(statistics)

                for i, key in enumerate(labels):
                    logger.add_scalar('average batch ' + str(key),
                                      all_losses[:, i].mean(), global_iteration)
                    #logger.add_histogram(str(key), all_losses[:, i], global_iteration)
                if is_validate:
                    _, output = model(data[0], target[0], inference=True)
                    render_flow = output[0].data.cpu().numpy().transpose(
                        1, 2, 0)
                    ground_truth = target[0][0].data.cpu().numpy().transpose(
                        1, 2, 0)
                    render_img = tools.flow_to_image(render_flow).transpose(
                        2, 0, 1)
                    true_img = tools.flow_to_image(ground_truth).transpose(
                        2, 0, 1)
                    render_img = torch.Tensor(render_img) / 255.0
                    true_img = torch.Tensor(true_img) / 255.0
                    input_img = data[0][0, :, 0, :, :].data.cpu() / 255.0
                    logger.add_image('renderimg',
                                     torchvision.utils.make_grid(render_img),
                                     global_iteration)
                    logger.add_image('ground_truth',
                                     torchvision.utils.make_grid(true_img),
                                     global_iteration)
                    logger.add_image('input_img',
                                     torchvision.utils.make_grid(input_img),
                                     global_iteration)

                # reset the running summary after each log point so that the
                # 'average batch' scalars average the batches logged in between
                statistics = []

            if (is_validate and (batch_idx == args.validation_n_batches)):
                break

            if ((not is_validate) and (batch_idx == (args.train_n_batches))):
                break

        progress.close()

        return total_loss / float(batch_idx + 1), (batch_idx + 1)
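
A minimal sketch of an epoch loop around this routine. total_epochs, train_loader, and val_loader are illustrative placeholders; the repository's real driver script also handles checkpointing and LR scheduling:

    # hypothetical epoch loop over train()
    global_iteration = 0
    for epoch in range(args.total_epochs):
        train_loss, n_batches = train(args, epoch, global_iteration,
                                      train_loader, model, optimizer, logger)
        global_iteration += n_batches
        val_loss, _ = train(args, epoch, global_iteration, val_loader,
                            model, optimizer, logger, is_validate=True)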