def train_step(engine, batch):
    """Run one mixed-precision optimisation step and return the scalar loss.

    The forward pass and loss are computed inside ``autocast()``. The
    backward pass, optimizer step and scale update all go through
    ``scaler``.

    Relies on ``device``, ``model``, ``criterion``, ``optimizer``, ``scaler``,
    ``autocast`` and ``convert_tensor`` being defined outside this function.

    Args:
        engine: The engine calling this step (unused here).
        batch: Indexable whose items 0 and 1 are the inputs and the targets.

    Returns:
        ``loss.item()`` of the loss returned by ``criterion`` (the scaled
        tensor is only used for backward).
    """
    inputs, targets = (
        convert_tensor(part, device, non_blocking=True)
        for part in (batch[0], batch[1])
    )
    optimizer.zero_grad()

    # Run the forward pass and the loss computation under autocasting.
    with autocast():
        loss = criterion(model(inputs), targets)

    # Backward is deliberately outside autocast (backward under autocast is
    # not recommended); each backward op runs in the precision autocast chose
    # for its forward counterpart. Calling it on the scaled loss yields
    # scaled gradients.
    scaler.scale(loss).backward()

    # step() unscales the gradients first; if they contain infs or NaNs the
    # optimizer.step() call is skipped. update() then adjusts the scale for
    # the next iteration.
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
def train_step(engine, batch):
    """Run one optimisation step without any loss scaling.

    Relies on ``device``, ``model``, ``criterion``, ``optimizer`` and
    ``convert_tensor`` being defined outside this function.

    Args:
        engine: The engine calling this step (unused here).
        batch: Indexable whose items 0 and 1 are the inputs and the targets.

    Returns:
        The loss for this batch as a Python number (``loss.item()``).
    """
    inputs, targets = (
        convert_tensor(batch[idx], device, non_blocking=True) for idx in (0, 1)
    )
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    return loss.item()
def prepare_batch(batch, device, non_blocking):
    """Convert the input and both target tensors of *batch* for *device*.

    *batch* must unpack into exactly four items. The fourth is ignored.
    Conversion is delegated to ``ie.convert_tensor``, where ``ie`` is defined
    outside this function.

    Returns:
        Tuple ``(input_data, class_target_data, recon_target_data)``, each
        passed through ``ie.convert_tensor(..., device=device,
        non_blocking=non_blocking)`` in that order.
    """
    inputs, class_targets, recon_targets, _ignored = batch
    return tuple(
        ie.convert_tensor(item, device=device, non_blocking=non_blocking)
        for item in (inputs, class_targets, recon_targets)
    )
def train_step(engine, batch):
    """Run one optimisation step, backpropagating through ``amp.scale_loss``.

    No autocast context is entered in this function. If the forward pass runs
    in mixed precision, that is presumably configured elsewhere (e.g. by
    ``amp.initialize``). TODO confirm.

    Relies on ``device``, ``model``, ``criterion``, ``optimizer``, ``amp`` and
    ``convert_tensor`` being defined outside this function.

    Args:
        engine: The engine calling this step (unused here).
        batch: Indexable whose items 0 and 1 are the inputs and the targets.

    Returns:
        ``loss.item()`` of the loss returned by ``criterion`` (not the scaled
        loss used for backward).
    """
    x = convert_tensor(batch[0], device, non_blocking=True)
    y = convert_tensor(batch[1], device, non_blocking=True)
    optimizer.zero_grad()
    y_pred = model(x)
    loss = criterion(y_pred, y)
    # amp.scale_loss yields a scaled copy of the loss. Backward is called on
    # that copy, and the optimizer steps only after the context has exited.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    return loss.item()
def prepare_batch(batch, device=None, non_blocking=False):
    """Split a mapping batch into converted input tensors and a target tensor.

    Key ``0`` of *batch* holds the target. Every other key holds an input
    tensor, and inputs are processed in ascending key order. *batch* is read
    but not modified.

    Target reduction before conversion:
      * 3-D target -> ``target[:, :, 0]`` flattened to 1-D
      * 2-D target -> ``target[:, 0]``
      * otherwise  -> unchanged

    Each input is passed through ``components`` (see below) and then, like
    the target, through ``engine.convert_tensor(..., device, non_blocking)``.
    ``engine`` is defined outside this function.

    Returns:
        ``(inputs, target)`` where ``inputs`` is a tuple of converted tensors.

    Raises:
        KeyError: if *batch* has no key ``0``.
    """

    def components(tensor: torch.Tensor) -> torch.Tensor:
        """Expand a 3-D ``(batch, 3, features)`` tensor into 7 channels.

        Channels 0, 1 and 2 are taken as x, y, z. The output stacks, along
        dim 1: x, y, z, f = sqrt(x²+y²+z²), h = sqrt(x²+y²), i = asin(z/f),
        d = asin(y/h). NaNs in f, h, i and d (e.g. 0/0 when f or h is zero)
        are replaced with 0. A tensor whose dim 1 is not 3 is returned
        unchanged (it must still be 3-D).
        """
        _, num_channels, _ = tensor.shape
        if num_channels != 3:
            # Pass the tensor itself through, not the surrounding batch.
            return tensor
        x = tensor[:, 0, :]
        y = tensor[:, 1, :]
        z = tensor[:, 2, :]
        f = torch.sqrt(x**2 + y**2 + z**2)
        f[torch.isnan(f)] = 0.0
        h = torch.sqrt(x**2 + y**2)
        h[torch.isnan(h)] = 0.0
        i = torch.asin(z / f)
        i[torch.isnan(i)] = 0.0
        # NOTE(review): if x/y/z are geomagnetic X/Y/Z components, declination
        # is usually atan2(y, x); asin(y / h) cannot tell x < 0 from x > 0.
        # Confirm intent.
        d = torch.asin(y / h)
        d[torch.isnan(d)] = 0.0
        # Stacking (batch, features) slices on dim 1 -> (batch, 7, features).
        return torch.stack((x, y, z, f, h, i, d), dim=1)

    target = batch[0]
    rank = len(target.shape)
    if rank == 3:
        target = target[:, :, 0].reshape(-1)
    elif rank == 2:
        target = target[:, 0]
    target = engine.convert_tensor(target, device, non_blocking)

    # Filter out the target key before sorting so only input keys are compared.
    input_items = sorted(
        ((key, value) for key, value in batch.items() if key != 0),
        key=lambda item: item[0],
    )
    inputs = tuple(
        engine.convert_tensor(components(value), device, non_blocking)
        for _, value in input_items
    )
    return inputs, target