예제 #1
0
def _single_volume_shape(shape):
    """Return the ``(1, 1, z, y, x)`` crop shape for one volume of a 5D batch.

    ``shape`` must have exactly five entries; they are unpacked as
    (index, channel, z, y, x), following the naming used by ``train_float``.
    """
    # Unpack every spatial axis under its own name; reusing one name for two
    # axes silently swaps the z extent for the x extent.
    _, _, size_z, size_y, size_x = shape
    return (1, 1, size_z, size_y, size_x)


def _log_volume_slices(tb_logger, x, y, prediction, step):
    """Log four z-slices of the first volume of input, target and prediction.

    Slices are taken at indices ``k * (size_z // 4)`` for ``k`` in 0..3 of the
    tensors cropped with ``actions.crop_tensor`` to a single-volume shape.
    """
    single_tomo_shape = _single_volume_shape(x.shape)
    size_z = single_tomo_shape[2]
    # Crop each tensor once and reuse the crop for every slice.
    # NOTE(review): these tags carry a 'val_' prefix although the images come
    # from training batches; kept unchanged so existing dashboards still match.
    volumes = (
        ('val_input', actions.crop_tensor(x, single_tomo_shape)[0, 0]),
        ('val_target', actions.crop_tensor(y, single_tomo_shape)[0, 0]),
        # detach so image logging does not hold on to the autograd graph
        ('val_prediction',
         actions.crop_tensor(prediction.detach(), single_tomo_shape)[0, 0]),
    )
    for slice_number in range(4):
        slice_index = slice_number * (size_z // 4)
        for tag, volume in volumes:
            # Move only the selected 2D slice to the CPU, not the whole crop.
            tb_logger.log_image(tag=tag, image=volume[slice_index].to('cpu'),
                                step=step)


def train_float(model, loader, optimizer, loss_function,
                epoch, device, log_interval=20, tb_logger=None):
    """Train ``model`` for one epoch, casting inputs and targets to float.

    Args:
        model: module to train; it is switched to train mode first.
        loader: iterable of ``(x, y)`` batches that supports ``len()`` and
            exposes a ``dataset`` attribute (used in the progress message).
        optimizer: optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_function: callable ``loss_function(prediction, target)``
            returning a scalar tensor.
        epoch: epoch number, used for console output and for the global
            tensorboard step ``epoch * len(loader) + batch_id``.
        device: device that each batch is moved to.
        log_interval: print progress every ``log_interval`` batches.
        tb_logger: optional logger providing ``log_scalar``, ``log_image`` and
            ``log_image_interval``. When image logging triggers, ``x`` must
            have five dimensions.

    Returns:
        The mean training loss over the epoch's batches. For an empty loader
        this is ``0.0``, and nothing is logged to tensorboard.
    """
    model.train()
    num_batches = len(loader)
    train_loss = 0.0
    for batch_id, (x, y) in enumerate(loader):
        # move input and target to the active device (either cpu or gpu)
        x = x.float().to(device)
        y = y.float().to(device)

        optimizer.zero_grad()
        prediction = model(x)
        loss = loss_function(prediction.float(), y)
        loss.backward()
        optimizer.step()

        # .item() synchronises with the device; read the scalar once per batch.
        batch_loss = loss.item()
        train_loss += batch_loss

        # log to console
        if batch_id % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_id * len(x), len(loader.dataset),
                100. * batch_id / num_batches, batch_loss))

        # log to tensorboard
        if tb_logger is not None:
            step = epoch * num_batches + batch_id
            tb_logger.log_scalar(tag='train_loss', value=batch_loss, step=step)
            if step % tb_logger.log_image_interval == 0:
                _log_volume_slices(tb_logger, x, y, prediction, step)

    if num_batches == 0:
        # Nothing was trained; avoid dividing by zero and logging a fake average.
        return train_loss

    train_loss /= num_batches
    if tb_logger is not None:
        tb_logger.log_scalar(tag='Average_train_loss', value=train_loss,
                             step=epoch * num_batches)
    return train_loss
예제 #2
0
 def _crop_and_concat(self, from_decoder, from_encoder):
     """Join a decoder tensor with its encoder skip connection.

     The encoder tensor is passed through ``actions.crop_tensor`` with
     ``from_decoder.shape`` as the target shape. The result is concatenated
     with ``from_decoder`` along dim 1 (presumably the channel axis --
     confirm), with the encoder features placed first.
     """
     decoder_shape = from_decoder.shape
     skip = actions.crop_tensor(from_encoder, decoder_shape)
     return torch.cat([skip, from_decoder], dim=1)