def _log_image_slices(tb_logger, x, y, prediction, step, n_slices=4):
    """Log ``n_slices`` evenly spaced slices of input, target and prediction.

    Only element ``[0, 0]`` of each cropped tensor is logged, i.e. one 2-D
    slice per tag and slice index.
    """
    # x must be 5-D for this unpacking to succeed; the names assume a
    # (batch, channel, z, y, x) layout -- TODO confirm against the data loader.
    _, _, size_z, size_y, size_x = x.shape
    single_tomo_shape = (1, 1, size_z, size_y, size_x)
    stride = size_z // n_slices

    # Crop once per tensor instead of once per logged slice.
    # NOTE: the tags carry a 'val_' prefix although these are training
    # batches; they are kept unchanged so existing logs keep lining up.
    panels = (
        ('val_input', actions.crop_tensor(x, single_tomo_shape)),
        ('val_target', actions.crop_tensor(y, single_tomo_shape)),
        ('val_prediction', actions.crop_tensor(prediction, single_tomo_shape)),
    )
    for k in range(n_slices):
        slice_index = k * stride
        for tag, cube in panels:
            tb_logger.log_image(tag=tag,
                                image=cube[0, 0, slice_index].to('cpu'),
                                step=step)


def train_float(model, loader, optimizer, loss_function, epoch, device,
                log_interval=20, tb_logger=None):
    """Run one training epoch, casting inputs and targets to float.

    Args:
        model: module to train; it is switched to train mode.
        loader: iterable of ``(x, y)`` batches that supports ``len()`` and
            exposes a ``dataset`` attribute (used for console progress).
        optimizer: optimizer whose ``zero_grad``/``step`` are called once per
            batch.
        loss_function: callable taking ``(prediction, target)`` and returning
            a scalar loss tensor.
        epoch: index of the current epoch; used for console output and to
            compute the global logging step.
        device: device the batches are moved to before the forward pass.
        log_interval: print the loss to the console every ``log_interval``
            batches.
        tb_logger: optional logger providing ``log_scalar``, ``log_image``
            and a ``log_image_interval`` attribute. Nothing is logged to
            tensorboard when it is ``None``.

    Returns:
        None. The model parameters are updated in place.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches, because the
            epoch average divides by ``len(loader)``.
    """
    model.train()
    num_batches = len(loader)
    train_loss = 0

    for batch_id, (x, y) in enumerate(loader):
        # cast to float first, then move to the active device (cpu or gpu)
        x = x.float().to(device)
        y = y.float().to(device)

        optimizer.zero_grad()

        # forward pass, loss, backward pass, parameter update
        prediction = model(x)
        loss = loss_function(prediction.float(), y.float())
        loss.backward()
        optimizer.step()

        # read the scalar once; it is reused for all logging below
        batch_loss = loss.item()
        train_loss += batch_loss

        # log to console
        if batch_id % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_id * len(x), len(loader.dataset),
                100. * batch_id / num_batches, batch_loss))

        # log to tensorboard
        if tb_logger is not None:
            step = epoch * num_batches + batch_id
            tb_logger.log_scalar(tag='train_loss', value=batch_loss, step=step)
            if step % tb_logger.log_image_interval == 0:
                # log four slices of the current training batch
                _log_image_slices(tb_logger, x, y, prediction, step)

    train_loss /= num_batches
    if tb_logger is not None:
        tb_logger.log_scalar(tag='Average_train_loss', value=train_loss,
                             step=epoch * num_batches)
def _crop_and_concat(self, from_decoder, from_encoder):
    """Crop the encoder tensor and join it with the decoder tensor.

    ``from_encoder`` is passed through ``actions.crop_tensor`` with
    ``from_decoder.shape`` as the target shape (presumably so both tensors
    agree in size -- confirm against ``crop_tensor``). The result is then
    concatenated with ``from_decoder`` along dimension 1, with the cropped
    encoder tensor first.
    """
    target_shape = from_decoder.shape
    encoder_matched = actions.crop_tensor(from_encoder, target_shape)
    # the encoder part comes first along dim 1, followed by the decoder part
    merged = torch.cat([encoder_matched, from_decoder], dim=1)
    return merged