def forward_batch(self, model, data_batch, mode, optimizer, criterion):
    imgsize = self.config['imgsize']
    is_train = mode == 'train'
    report_image_freq = self.config['report_image_freq']
    temporal_encode = self.config['temporal_encode']
    thickness = self.config['thickness']

    # Unpack data batch
    points = data_batch['points3'].to(self.device)
    category = data_batch['category'].to(self.device)
    if temporal_encode:
        intensities = data_batch['intensities']
    else:
        intensities = 1.0  # constant stroke intensity when temporal encoding is off

    # Rasterization
    sketches_image = Raster.to_image(points, intensities, imgsize, thickness, device=self.device)

    # Forward: replicate the single-channel raster to 3 channels for the CNN
    if is_train:
        optimizer.zero_grad()
    with torch.set_grad_enabled(is_train):
        logits = model(sketches_image.repeat(1, 3, 1, 1))
        loss = criterion(logits, category)

    # Backward
    if is_train:
        loss.backward()
        optimizer.step()

    if report_image_freq > 0 and self.step_counters[mode] % report_image_freq == 0:
        image_grid = torchvision.utils.make_grid(sketches_image, nrow=4)
        self.reporter.add_image('{}/sketch_input'.format(mode),
                                image_grid,
                                self.step_counters[mode])

    return logits, loss, category
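# A minimal, hypothetical sketch of driving forward_batch above from a training
# loop. `trainer` stands in for an instance of the surrounding class, and
# `train_loader`, `model`, `optimizer`, `criterion` for its data pipeline and
# optimisation objects; none of these names are defined in this file.
def _example_train_epoch(trainer, model, train_loader, optimizer, criterion):
    running_correct, running_total = 0, 0
    for data_batch in train_loader:
        trainer.step_counters['train'] += 1  # assumed here; the real loop may bump this elsewhere
        logits, loss, category = trainer.forward_batch(
            model, data_batch, 'train', optimizer, criterion)
        _, predicted = torch.max(logits, 1)
        running_correct += (predicted == category).sum().item()
        running_total += category.size(0)
    return running_correct / max(running_total, 1)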
def forward_sample(self, model, batch_data, index, drawing_ratio):
    imgsize = self.config['imgsize']
    thickness = self.config['thickness']

    points = batch_data['points3'].to(self.device)
    category = batch_data['category'].to(self.device)
    if self.config['temporal_encode']:
        intensities = batch_data['intensities']
    else:
        intensities = 1.0

    start_time = time.time()

    # Rasterization
    sketches_image = Raster.to_image(points, intensities, imgsize, thickness, device=self.device)
    logits = model(sketches_image.repeat(1, 3, 1, 1))

    duration = time.time() - start_time

    # _, predicts = torch.max(logits, 1)
    # if drawing_ratio == 1:
    #     self.collect_stats.append((index, predicts.cpu().numpy()[0], batch_data['category'][0]))

    return logits, category, duration
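# A minimal, hypothetical sketch of evaluating one sketch at several completion
# ratios with forward_sample above. It assumes the data pipeline (not shown
# here) already truncates 'points3' / 'intensities' in batch_data to the
# requested drawing_ratio; forward_sample itself does not truncate, and
# `dataset.get_partial_sample` is an invented helper name.
def _example_eval_partial(trainer, model, dataset, index):
    results = []
    for drawing_ratio in (0.25, 0.5, 0.75, 1.0):
        batch_data = dataset.get_partial_sample(index, drawing_ratio)  # hypothetical helper
        logits, category, duration = trainer.forward_sample(model, batch_data, index, drawing_ratio)
        _, predict = torch.max(logits, 1)
        results.append((drawing_ratio, predict.item(), category.item(), duration))
    return results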
def forward_batch(self, model, data_batch, mode, optimizer, criterion):
    imgsize = self.config['imgsize']
    is_train = mode == 'train'
    report_hist_freq = self.config['report_hist_freq']
    report_image_freq = self.config['report_image_freq']
    thickness = self.config['thickness']

    # Unpack data batch
    points = data_batch['points3'].to(self.device)
    points_offset = data_batch['points3_offset'].to(self.device)
    points_length = data_batch['points3_length']
    category = data_batch['category'].to(self.device)

    # Original input without offsetting
    if report_image_freq > 0 and self.step_counters[mode] % report_image_freq == 0:
        images = Raster.to_image(points, 1.0, imgsize, thickness, device=self.device)
        image_grid = torchvision.utils.make_grid(images, nrow=4)
        self.reporter.add_image('{}/sketch_input'.format(mode),
                                image_grid,
                                self.step_counters[mode])

    # Forward
    if is_train:
        optimizer.zero_grad()
    with torch.set_grad_enabled(is_train):
        logits, attention, images = model(points, points_offset, points_length)
        loss = criterion(logits, category)

    # Backward
    if is_train:
        loss.backward()
        optimizer.step()

    # if report_image_freq > 0 and self.step_counters[mode] % report_image_freq == 0:
    #     image_grid = torchvision.utils.make_grid(images, nrow=4)
    #     self.reporter.add_image('{}/sketch_attention'.format(mode),
    #                             image_grid,
    #                             self.step_counters[mode])

    if is_train and report_hist_freq > 0 and self.step_counters[mode] % report_hist_freq == 0:
        self.reporter.add_histogram('{}/attention'.format(mode),
                                    attention,
                                    self.step_counters[mode],
                                    bins='auto')
        self.reporter.add_histogram('{}/points_length'.format(mode),
                                    points_length,
                                    self.step_counters[mode],
                                    bins='auto')

    return logits, loss, category
def __call__(self, points, points_offset, lengths):
    # === RNN ===
    # Compute point-wise attention
    intensities, seqfeat = self.rnn(points_offset, lengths)

    # Rasterization and inject into CNN after stage 2
    attention = RasterIntensityFunc.apply(points, intensities, 56, 0.5, self.eps, self.device)

    # === CNN ===
    images = Raster.to_image(points, 1.0, self.img_size, self.thickness, device=self.device)
    cnnfeat = self.cnn(images.repeat(1, 3, 1, 1), attention)
    logits = self.fc(cnnfeat)

    return logits, (images, attention)
def __call__(self, points, points_offset, lengths):
    # === RNN ===
    batch_size = points_offset.shape[0]  # [batch_size, num_points, 3]

    # Pack the padded offset sequences so the RNN skips padding steps
    points_offset_packed = pack_padded_sequence(
        points_offset, lengths, batch_first=self.rnn_batch_first)
    hiddens_packed, (last_hidden, _) = self.rnn(points_offset_packed)
    last_hidden = last_hidden.view(batch_size, -1)

    # === CNN ===
    images = Raster.to_image(points, 1.0, self.img_size, self.thickness, device=self.device)
    cnnfeat = self.cnn(images.repeat(1, 3, 1, 1))

    # === FC ===
    # Late fusion: concatenate the RNN and CNN features, then classify
    logits = self.fc2(
        F.relu(self.fc1(torch.cat((last_hidden, cnnfeat), 1))))

    return logits, images
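# A minimal, hypothetical shape sketch for the late-fusion forward pass above.
# The [batch_size, num_points, 3] layout and the long-typed lengths are
# assumptions inferred from the 'points3' / 'points3_offset' batches used in
# forward_batch; the authoritative contract lives in the dataset code, which
# is not shown here, and the zero-filled tensors are placeholders only.
def _example_forward_shapes(model, batch_size=8, num_points=256):
    points = torch.zeros(batch_size, num_points, 3)   # absolute (x, y, pen state)
    points_offset = torch.zeros_like(points)          # per-step offsets fed to the RNN
    lengths = torch.full((batch_size,), num_points, dtype=torch.long)
    logits, images = model(points, points_offset, lengths)
    return logits.shape, images.shape                 # logits: [batch_size, num_categories]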