def inference(args, epoch, data_loader, logger, model, offset=0):
    model.eval()

    if args.save_flow or args.render_validation:
        flow_folder = "{}/{}.epoch-{}-flow-field".format(
            args.inference_dir, args.name.replace('/', '.'), epoch)
        rendered_flow_folder = "{}/{}.epoch-{}-rendered-flow-field".format(
            args.inference_dir, args.name.replace('/', '.'), epoch)
        if not os.path.exists(flow_folder):
            os.makedirs(flow_folder)
        if not os.path.exists(rendered_flow_folder):
            os.makedirs(rendered_flow_folder)

    args.inference_n_batches = np.inf if args.inference_n_batches < 0 else args.inference_n_batches

    progress = tqdm(data_loader,
                    ncols=100,
                    total=np.minimum(len(data_loader), args.inference_n_batches),
                    desc='Inferencing ',
                    leave=True,
                    position=offset)

    statistics = []
    total_loss = 0
    for batch_idx, (data, target) in enumerate(progress):
        if args.cuda:
            data = [d.cuda(async=True) for d in data]
            target = [t.cuda(async=True) for t in target]
        data = [Variable(d, volatile=True) for d in data]
        target = [Variable(t, volatile=True) for t in target]

        # When ground-truth flows are not available for inference_dataset,
        # the targets are set to all zeros. The losses are therefore the L1 or
        # L2 norms of the computed optical flows, depending on the loss norm
        # passed in.
        losses, output = model(data[0], target[0], inference=True)

        losses = [torch.mean(loss_value) for loss_value in losses]
        loss_val = losses[0]  # Collect first loss for weight update
        total_loss += loss_val.data[0]
        loss_values = [v.data[0] for v in losses]

        # Gather loss_labels; a direct return leads to a recursion-limit error
        # as it looks for variables to gather.
        loss_labels = list(model.module.loss.loss_labels)

        statistics.append(loss_values)
        # import IPython; IPython.embed()

        if args.save_flow or args.render_validation:
            for i in range(args.inference_batch_size):
                _pflow = output[i].data.cpu().numpy().transpose(1, 2, 0)
                ground_truth = target[0][i].data.cpu().numpy().transpose(1, 2, 0)
                render_img = tools.flow_to_image(_pflow).transpose(2, 0, 1)
                true_img = tools.flow_to_image(ground_truth).transpose(2, 0, 1)
                render_img = torch.Tensor(render_img) / 255.0
                true_img = torch.Tensor(true_img) / 255.0
                input_img = data[0][i, :, 0, :, :].data.cpu() / 255.0
                logger.add_image('renderimg',
                                 torchvision.utils.make_grid(render_img),
                                 batch_idx * args.inference_batch_size + i)
                logger.add_image('ground_truth',
                                 torchvision.utils.make_grid(true_img),
                                 batch_idx * args.inference_batch_size + i)
                logger.add_image('input_img',
                                 torchvision.utils.make_grid(input_img),
                                 batch_idx * args.inference_batch_size + i)
                if args.save_flow:
                    scipy.misc.imsave(
                        join(rendered_flow_folder,
                             '%06d.png' % (batch_idx * args.inference_batch_size + i)),
                        render_img.numpy().transpose(1, 2, 0))
                    flow_utils.writeFlow(
                        join(flow_folder,
                             '%06d.flo' % (batch_idx * args.inference_batch_size + i)),
                        _pflow)

        progress.set_description(
            'Inference Averages for Epoch {}: '.format(epoch) +
            tools.format_dictionary_of_losses(loss_labels,
                                              np.array(statistics).mean(axis=0)))
        progress.update(1)

        if batch_idx == (args.inference_n_batches - 1):
            break

    progress.close()

    return
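# Hedged helper (not part of the original training script): a standalone
# reader for the .flo files that the inference loop above saves through
# flow_utils.writeFlow. It assumes the standard Middlebury .flo layout
# (float32 magic 202021.25, int32 width/height, then interleaved u/v float32
# values); if flow_utils deviates from that layout, prefer its own reader.
def read_flo_file(path):
    """Read a Middlebury .flo optical-flow file into an (H, W, 2) float32 array."""
    with open(path, 'rb') as fh:
        magic = np.fromfile(fh, np.float32, count=1)[0]
        assert np.isclose(magic, 202021.25), 'invalid .flo magic number'
        width = int(np.fromfile(fh, np.int32, count=1)[0])
        height = int(np.fromfile(fh, np.int32, count=1)[0])
        data = np.fromfile(fh, np.float32, count=2 * width * height)
    return data.reshape(height, width, 2)

# Example: flow = read_flo_file(join(flow_folder, '000000.flo'))
# gives the raw (u, v) field whose rendering was written to rendered_flow_folder.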
def forward(self, f, b, mask=None):
    """Contextual attention layer implementation.

    Contextual attention was first introduced in the publication:
    Generative Image Inpainting with Contextual Attention, Yu et al.

    Args:
        f: Input feature to match (foreground).
        b: Input feature to match against (background).
        mask: Input mask for b, indicating patches that are not available.
        ksize: Kernel size for contextual attention.
        stride: Stride for extracting patches from b.
        rate: Dilation for matching.
        softmax_scale: Scale of the softmax used for attention.
    Returns:
        torch.tensor: output
    """
    # get shapes
    raw_int_fs = list(f.size())  # b*c*h*w
    raw_int_bs = list(b.size())  # b*c*h*w

    # extract patches from background with stride and rate
    kernel = 2 * self.rate
    # raw_w is extracted for reconstruction
    raw_w = extract_image_patches(b,
                                  ksizes=[kernel, kernel],
                                  strides=[self.rate, self.rate])  # b*hw*c*k*k
    raw_w_groups = torch.split(raw_w, 1, dim=0)

    # downscaling foreground option: downscale both foreground and background
    # for matching, and use the original background for reconstruction.
    f = F.interpolate(f, scale_factor=1 / self.rate, mode='nearest')
    b = F.interpolate(b, scale_factor=1 / self.rate, mode='nearest')
    int_fs = list(f.size())  # b*c*h*w
    int_bs = list(b.size())
    f_groups = torch.split(f, 1, dim=0)  # split tensors along the batch dimension
    w = extract_image_patches(b,
                              ksizes=[self.ksize, self.ksize],
                              strides=[self.stride, self.stride])  # b*hw*c*k*k
    w_groups = torch.split(w, 1, dim=0)

    # process mask
    if mask is None:
        mask = torch.zeros([int_bs[0], 1, int_bs[2], int_bs[3]])
        if self.use_cuda:
            mask = mask.cuda()
    else:
        mask = F.interpolate(mask,
                             scale_factor=1. / (4. * self.rate),
                             mode='nearest')
    m_groups = extract_image_patches(mask,
                                     ksizes=[self.ksize, self.ksize],
                                     strides=[self.stride, self.stride])  # b*hw*c*k*k
    # m = m[0]  # hw*c*k*k
    # m = reduce_mean(m, axis=[1, 2, 3])  # hw*1*1*1
    # m = m.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
    # mm = (m == 0).to(torch.float32)  # 1*hw*1*1

    y = []
    offsets = []
    k = self.fuse_k
    scale = self.softmax_scale * 255  # to fit the PyTorch tensor image value range
    fuse_weight = torch.eye(k).view(1, 1, k, k)  # 1*1*k*k
    if self.use_cuda:
        fuse_weight = fuse_weight.cuda()

    for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups, m_groups):
        '''
        O => output channel as a conv filter
        I => input channel as a conv filter
        xi : separated tensor along the batch dimension of the foreground;
             (B=1, C=128, H=32, W=32)
        wi : separated patch tensor along the batch dimension of the background;
             (B=1, O=32*32, I=128, KH=3, KW=3)
        raw_wi : separated tensor along the batch dimension of the background;
             (B=1, I=32*32, O=128, KH=4, KW=4)
        '''
        # conv for compare
        escape_NaN = torch.FloatTensor([1e-4])
        if self.use_cuda:
            escape_NaN = escape_NaN.cuda()
        wi = wi[0]  # hw*c*k*k
        wi_normed = wi / torch.max(
            torch.sqrt(reduce_sum(torch.pow(wi, 2), axis=[1, 2, 3])),
            escape_NaN)
        xi_normed = same_padding(xi, [self.ksize, self.ksize], [1, 1])  # xi: 1*c*H*W
        yi = F.conv2d(xi_normed, wi_normed, stride=1)  # 1*hw*H*W

        # conv implementation for fusing scores, to encourage large patches
        if self.fuse:
            # make all of depth to spatial resolution
            yi = yi.view(1, 1, int_bs[2] * int_bs[3],
                         int_fs[2] * int_fs[3])  # (B=1, I=1, H=32*32, W=32*32)
            yi = same_padding(yi, [k, k], [1, 1])
            yi = F.conv2d(yi, fuse_weight, stride=1)  # (B=1, C=1, H=32*32, W=32*32)
            yi = yi.contiguous().view(1, int_bs[2], int_bs[3],
                                      int_fs[2], int_fs[3])  # (B=1, 32, 32, 32, 32)
            yi = yi.permute(0, 2, 1, 4, 3)
            yi = yi.contiguous().view(1, 1, int_bs[2] * int_bs[3],
                                      int_fs[2] * int_fs[3])
            yi = same_padding(yi, [k, k], [1, 1])
            yi = F.conv2d(yi, fuse_weight, stride=1)
            yi = yi.contiguous().view(1, int_bs[3], int_bs[2],
                                      int_fs[3], int_fs[2])
            yi = yi.permute(0, 2, 1, 4, 3)
        yi = yi.contiguous().view(1, int_bs[2] * int_bs[3],
                                  int_fs[2], int_fs[3])  # (B=1, C=32*32, H=32, W=32)

        # mi: hw*c*k*k
        mi = reduce_mean(mi, axis=[1, 2, 3])  # hw*1*1*1
        mi = mi.permute(1, 0, 2, 3).contiguous()  # 1*hw*1*1
        mm = (mi == 0).to(torch.float32)  # 1*hw*1*1

        # softmax to match
        yi = yi * mm
        yi = F.softmax(yi * scale, dim=1)
        yi = yi * mm  # 1*hw*H*W

        offset = torch.argmax(yi, dim=1, keepdim=True)  # 1*1*H*W

        if int_bs != int_fs:
            # Normalize the offset value to match the foreground dimension
            times = float(int_fs[2] * int_fs[3]) / float(int_bs[2] * int_bs[3])
            offset = ((offset + 1).float() * times - 1).to(torch.int64)
        offset = torch.cat([offset // int_fs[3], offset % int_fs[3]],
                           dim=1)  # 1*2*H*W

        # deconv for patch pasting
        wi_center = raw_wi[0]
        yi = F.conv_transpose2d(yi, wi_center, stride=self.rate,
                                padding=1) / 4.  # (B=1, C=128, H=64, W=64)
        y.append(yi)
        offsets.append(offset)

    y = torch.cat(y, dim=0)  # back to the mini-batch
    y.contiguous().view(raw_int_fs)

    offsets = torch.cat(offsets, dim=0)
    offsets = offsets.view(int_fs[0], 2, *int_fs[2:])

    # case 1: visualize optical flow: subtract the current position
    h_add = torch.arange(int_fs[2]).view([1, 1, int_fs[2], 1]).expand(
        int_fs[0], -1, -1, int_fs[3])
    w_add = torch.arange(int_fs[3]).view([1, 1, 1, int_fs[3]]).expand(
        int_fs[0], -1, int_fs[2], -1)
    ref_coordinate = torch.cat([h_add, w_add], dim=1)  # b*2*H*W
    if self.use_cuda:
        ref_coordinate = ref_coordinate.cuda()

    offsets = offsets - ref_coordinate
    # flow = pt_flow_to_image(offsets)

    flow = torch.from_numpy(
        flow_to_image(offsets.permute(0, 2, 3, 1).cpu().data.numpy())) / 255.
    flow = flow.permute(0, 3, 1, 2)
    if self.use_cuda:
        flow = flow.cuda()
    # case 2: visualize which pixels are attended
    # flow = torch.from_numpy(highlight_flow((offsets * mask.long()).cpu().data.numpy()))

    if self.rate != 1:
        flow = F.interpolate(flow,
                             scale_factor=self.rate * 4,
                             mode='nearest')

    return y, flow
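# Hedged sketch (not used by the layer above) of its core matching step:
# background patches are scaled to unit L2 norm and applied as convolution
# filters over the foreground, so channel j of the output scores how well each
# foreground location matches background patch j. The unfold-based patch
# extraction and the function/argument names below are illustrative stand-ins
# for extract_image_patches / same_padding, not the layer's exact helpers.
def patch_matching_scores(fg, bg, ksize=3, stride=1, eps=1e-4):
    """Return (1, num_bg_patches, H, W) matching scores for a single sample."""
    # Extract background patches: (1, C*k*k, L) -> (L, C, k, k),
    # one conv filter per background patch.
    patches = F.unfold(bg, kernel_size=ksize, stride=stride, padding=ksize // 2)
    num_patches = patches.shape[-1]
    patches = patches.transpose(1, 2).reshape(num_patches, bg.shape[1], ksize, ksize)
    # Normalise each patch to unit L2 norm, with a small floor to avoid
    # dividing by zero (same role as escape_NaN in the layer above).
    norm = patches.reshape(num_patches, -1).norm(p=2, dim=1).clamp(min=eps)
    filters = patches / norm.view(num_patches, 1, 1, 1)
    # Correlate the foreground with every unit-norm background patch.
    return F.conv2d(fg, filters, stride=1, padding=ksize // 2)

# Example: patch_matching_scores(torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16))
# -> shape (1, 256, 16, 16): one score map per 3x3 background patch, analogous
# to the 1*hw*H*W tensor that the softmax-attention step operates on.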
def train(args, epoch, start_iteration, data_loader, model, optimizer, logger,
          is_validate=False, offset=0):
    statistics = []
    total_loss = 0

    if is_validate:
        model.eval()
        title = 'Validating Epoch {}'.format(epoch)
        args.validation_n_batches = len(
            data_loader) - 1 if args.validation_n_batches < 0 else args.validation_n_batches
        progress = tqdm(tools.IteratorTimer(data_loader),
                        ncols=100,
                        total=np.minimum(len(data_loader), args.validation_n_batches),
                        leave=True,
                        position=offset,
                        desc=title)
    else:
        model.train()
        title = 'Training Epoch {}'.format(epoch)
        args.train_n_batches = len(
            data_loader) - 1 if args.train_n_batches < 0 else args.train_n_batches
        progress = tqdm(tools.IteratorTimer(data_loader),
                        ncols=120,
                        total=np.minimum(len(data_loader), args.train_n_batches),
                        smoothing=.9,
                        miniters=1,
                        leave=True,
                        position=offset,
                        desc=title)

    last_log_time = progress._time()
    for batch_idx, (data, target) in enumerate(progress):

        data = [Variable(d, volatile=is_validate) for d in data]
        target = [Variable(t, volatile=is_validate) for t in target]
        if args.cuda and args.number_gpus == 1:
            data = [d.cuda(async=True) for d in data]
            target = [t.cuda(async=True) for t in target]

        if not is_validate:
            optimizer.zero_grad()

        losses = model(data[0], target[0])
        losses = [torch.mean(loss_value) for loss_value in losses]
        loss_val = losses[1]  # Collect the loss used for the weight update
        total_loss += loss_val.data[0]
        loss_values = [v.data[0] for v in losses]

        # Gather loss_labels; a direct return leads to a recursion-limit error
        # as it looks for variables to gather.
        loss_labels = list(model.module.loss.loss_labels)

        assert not np.isnan(total_loss)

        if not is_validate and args.fp16:
            loss_val.backward()
            if args.gradient_clip:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              args.gradient_clip)

            # Copy gradients onto the fp32 master parameters (param_copy is
            # maintained outside this function) and undo the loss scaling.
            params = list(model.parameters())
            for i in range(len(params)):
                param_copy[i].grad = params[i].grad.clone().type_as(
                    params[i]).detach()
                param_copy[i].grad.mul_(1. / args.loss_scale)
            optimizer.step()
            # Write the updated fp32 master weights back into the fp16 model.
            for i in range(len(params)):
                params[i].data.copy_(param_copy[i].data)

        elif not is_validate:
            loss_val.backward()
            if args.gradient_clip:
                torch.nn.utils.clip_grad_norm(model.parameters(),
                                              args.gradient_clip)
            optimizer.step()

        # Update hyperparameters if needed
        global_iteration = start_iteration + batch_idx
        if not is_validate:
            tools.update_hyperparameter_schedule(args, epoch, global_iteration,
                                                 optimizer)
            loss_labels.append('lr')
            loss_values.append(optimizer.param_groups[0]['lr'])

        loss_labels.append('load')
        loss_values.append(progress.iterable.last_duration)

        # Print out statistics
        statistics.append(loss_values)
        title = '{} Epoch {}'.format('Validating' if is_validate else 'Training',
                                     epoch)

        if (type(loss_labels[0]) is list) or (type(loss_labels[0]) is tuple):
            progress.set_description(
                title + ' ' + tools.format_dictionary_of_losses(
                    loss_labels[0], statistics[-1]))
        else:
            progress.set_description(
                title + ' ' + tools.format_dictionary_of_losses(
                    loss_labels, statistics[-1]))

        if ((((global_iteration + 1) % args.log_frequency) == 0 and not is_validate)
                or (is_validate and batch_idx == args.validation_n_batches - 1)):

            global_iteration = global_iteration if not is_validate else start_iteration

            logger.add_scalar(
                'batch logs per second',
                len(statistics) / (progress._time() - last_log_time),
                global_iteration)
            last_log_time = progress._time()

            all_losses = np.array(statistics)

            for i, key in enumerate(loss_labels[0] if
                                    (type(loss_labels[0]) is list) or
                                    (type(loss_labels[0]) is tuple)
                                    else loss_labels):
                logger.add_scalar('average batch ' + str(key),
                                  all_losses[:, i].mean(), global_iteration)
                # logger.add_histogram(str(key), all_losses[:, i], global_iteration)

            if is_validate:
                _, output = model(data[0], target[0], inference=True)
                render_flow = output[0].data.cpu().numpy().transpose(1, 2, 0)
                ground_truth = target[0][0].data.cpu().numpy().transpose(1, 2, 0)
                render_img = tools.flow_to_image(render_flow).transpose(2, 0, 1)
                true_img = tools.flow_to_image(ground_truth).transpose(2, 0, 1)
                render_img = torch.Tensor(render_img) / 255.0
                true_img = torch.Tensor(true_img) / 255.0
                input_img = data[0][0, :, 0, :, :].data.cpu() / 255.0
                logger.add_image('renderimg',
                                 torchvision.utils.make_grid(render_img),
                                 global_iteration)
                logger.add_image('ground_truth',
                                 torchvision.utils.make_grid(true_img),
                                 global_iteration)
                logger.add_image('input_img',
                                 torchvision.utils.make_grid(input_img),
                                 global_iteration)

            # Reset Summary
            statistics = []

        if (is_validate and (batch_idx == args.validation_n_batches)):
            break

        if ((not is_validate) and (batch_idx == (args.train_n_batches))):
            break

    progress.close()

    return total_loss / float(batch_idx + 1), (batch_idx + 1)
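# Hedged sketch of the fp32 "master parameter" setup that the fp16 branch of
# `train` relies on. param_copy is not defined in this file; in this style of
# mixed-precision loop it is typically created once, next to the optimizer, as
# fp32 clones of the half-precision model parameters, and the optimizer is
# built over those copies (so optimizer.step() updates the fp32 weights, which
# `train` then copies back into the fp16 model). The helper name and the use
# of Adam/args.lr below are illustrative, not this repo's exact setup code.
def build_fp16_master_params(model, args):
    """Create fp32 master copies of a half-precision model's parameters."""
    param_copy = [
        p.clone().type(torch.cuda.FloatTensor).detach()
        for p in model.parameters()
    ]
    for p in param_copy:
        # Gradients are written onto these copies each step (see the fp16
        # branch of `train`, which also rescales them by 1 / args.loss_scale).
        p.requires_grad = True
    optimizer = torch.optim.Adam(param_copy, lr=args.lr)  # illustrative choice
    return param_copy, optimizer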