Example #1
    def train_step(engine, batch):
        x = convert_tensor(batch[0], device, non_blocking=True)
        y = convert_tensor(batch[1], device, non_blocking=True)

        optimizer.zero_grad()

        # Runs the forward pass with autocasting.
        with autocast():
            y_pred = model(x)
            loss = criterion(y_pred, y)

        # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
        # Backward passes under autocast are not recommended.
        # Backward ops run in the same precision that autocast used for corresponding forward ops.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
        # otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

        return loss.item()
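The step above closes over model, optimizer, criterion, scaler and device from an enclosing scope. A minimal sketch of what that setup might look like, assuming convert_tensor comes from ignite.utils; MyModel is a placeholder and the optimizer, loss and learning rate are arbitrary choices:

    import torch
    from torch.cuda.amp import autocast, GradScaler
    from ignite.utils import convert_tensor

    device = torch.device("cuda")
    model = MyModel().to(device)                    # placeholder model, not part of the original
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    scaler = GradScaler()                           # keeps the running loss-scale factor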
Example #2
    def train_step(engine, batch):
        x = convert_tensor(batch[0], device, non_blocking=True)
        y = convert_tensor(batch[1], device, non_blocking=True)

        optimizer.zero_grad()

        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()

        optimizer.step()

        return loss.item()
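Both step functions use Ignite's process-function signature (engine, batch), so they can be attached to an Engine directly. A hedged sketch of that wiring, with train_loader as a placeholder DataLoader:

    from ignite.engine import Engine

    trainer = Engine(train_step)
    trainer.run(train_loader, max_epochs=10)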
Example #3
def prepare_batch(batch, device, non_blocking):
    # Split the batch; its fourth element is unused here.
    input_data, class_target_data, recon_target_data, _ = batch
    input_data = ie.convert_tensor(input_data,
                                   device=device,
                                   non_blocking=non_blocking)
    class_target_data = ie.convert_tensor(class_target_data,
                                          device=device,
                                          non_blocking=non_blocking)
    recon_target_data = ie.convert_tensor(recon_target_data,
                                          device=device,
                                          non_blocking=non_blocking)

    return input_data, class_target_data, recon_target_data
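This prepare_batch returns the input together with a classification target and a reconstruction target. A hypothetical consumer might combine two losses as sketched below; the two-headed model, class_criterion and recon_criterion are assumptions, not part of the original code:

    def train_step(engine, batch):
        x, y_class, y_recon = prepare_batch(batch, device, non_blocking=True)

        optimizer.zero_grad()
        class_pred, recon_pred = model(x)    # assumed model with two output heads
        loss = (class_criterion(class_pred, y_class) +
                recon_criterion(recon_pred, y_recon))
        loss.backward()
        optimizer.step()

        return loss.item()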
Example #4
    def train_step(engine, batch):
        x = convert_tensor(batch[0], device, non_blocking=True)
        y = convert_tensor(batch[1], device, non_blocking=True)

        optimizer.zero_grad()

        y_pred = model(x)
        loss = criterion(y_pred, y)

        # Scales the loss with apex AMP; backward() on the scaled loss creates scaled
        # gradients, which apex unscales when the context manager exits.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        optimizer.step()

        return loss.item()
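Here the loss scaling comes from NVIDIA's apex library rather than torch.cuda.amp, which requires wrapping the model and optimizer once before training. A minimal sketch of that assumed setup (the opt_level is an arbitrary choice):

    from apex import amp

    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")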
Example #5
        def prepare_batch(batch, device=None, non_blocking=False):
            def components(tensor: torch.Tensor):
                # For a 3-channel tensor, derive four extra channels from the raw
                # x/y/z components: the total magnitude f, the horizontal magnitude h,
                # and the angles i = asin(z / f) and d = asin(y / h). NaNs (e.g. from
                # NaN inputs or 0 / 0 divisions) are replaced with 0.
                batch_size, num_channels, num_features = tensor.shape
                if num_channels == 3:
                    x = tensor[:, 0, :].reshape(-1)
                    y = tensor[:, 1, :].reshape(-1)
                    z = tensor[:, 2, :].reshape(-1)
                    f = torch.sqrt(x**2 + y**2 + z**2)
                    f[torch.isnan(f)] = 0.0
                    h = torch.sqrt(x**2 + y**2)
                    h[torch.isnan(h)] = 0.0
                    i = torch.asin(z / f)
                    i[torch.isnan(i)] = 0.0
                    d = torch.asin(y / h)
                    d[torch.isnan(d)] = 0.0
                    # Stack raw and derived components along the channel axis:
                    # the result has shape (batch_size, 7, num_features).
                    return torch.cat(tuple(
                        map(lambda it: it.reshape(batch_size, 1, num_features),
                            (x, y, z, f, h, i, d))),
                                     dim=1)
                else:
                    # Tensors without exactly three channels are passed through unchanged.
                    return tensor

            # Key 0 of the batch dict holds the target; the remaining integer keys
            # hold the input tensors. For 2-D and 3-D targets, only the first entry
            # along the last dimension is kept.
            y = batch.pop(0)
            if len(y.shape) == 3:
                y = engine.convert_tensor(y[:, :, 0].reshape(-1), device,
                                          non_blocking)
            elif len(y.shape) == 2:
                y = engine.convert_tensor(y[:, 0], device, non_blocking)
            else:
                y = engine.convert_tensor(y, device, non_blocking)
            # Convert the remaining entries in key order, expanding each
            # 3-channel tensor into its derived components.
            x = tuple(
                map(
                    lambda it: engine.convert_tensor(components(it[1]), device,
                                                     non_blocking),
                    sorted(batch.items(), key=lambda it: it[0])))
            return x, y
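The batch here is expected to be a dict keyed by integers: key 0 holds the target and the remaining keys hold 3-channel tensors, each expanded to seven channels by components. A hedged usage sketch with made-up shapes, assuming prepare_batch is in scope and that the engine name used above resolves to a module exposing convert_tensor (e.g. ignite.utils):

    import torch
    import ignite.utils as engine    # assumption: binds the `engine.convert_tensor` used above

    batch = {
        0: torch.randn(8, 2),        # target: only column 0 is kept -> y has shape (8,)
        1: torch.randn(8, 3, 128),   # 3-channel input -> expanded to shape (8, 7, 128)
        2: torch.randn(8, 3, 128),
    }
    x, y = prepare_batch(batch, device="cpu")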