Example #1
def proto_net_episode(model: Module, optimiser: Optimizer, loss_fn: Callable,
                      input_ids: torch.Tensor, attention_mask: torch.Tensor,
                      y: torch.Tensor, n_shot: int, k_way: int, q_queries: int,
                      distance: str, train: bool):
    """Performs a single training episode for a Prototypical Network.

    # Arguments
        model: Prototypical Network to be trained.
        optimiser: Optimiser to calculate gradient step
        loss_fn: Loss function to compute between predictions and targets. Should be cross-entropy
        input_ids: Token ids of the samples in the few shot classification task
        attention_mask: Attention mask of the samples in the few shot classification task
        y: Input labels of few shot classification task
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between class prototypes and queries
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Prototypical Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model(input_ids, attention_mask)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]
    prototypes = compute_prototypes(support, k_way, n_shot)

    # Calculate squared distances between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = pairwise_distances(queries, prototypes, distance)

    # Calculate log p_{phi} (y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y)

    # Prediction probabilities are softmax over distances
    y_pred = (-distances).softmax(dim=1)

    if train:
        # Take gradient step
        loss.backward()
        optimiser.step()

    return loss, y_pred
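compute_prototypes and pairwise_distances are assumed but not shown above. A minimal sketch of what they could look like, given the sample ordering described in the comments (the real helpers may differ):

import torch


def compute_prototypes(support: torch.Tensor, k: int, n: int) -> torch.Tensor:
    # support has shape (n * k, dim) and is ordered class-by-class, so reshaping
    # to (k, n, dim) and averaging over dim 1 yields one prototype per class.
    return support.reshape(k, n, -1).mean(dim=1)


def pairwise_distances(x: torch.Tensor, y: torch.Tensor, matching_fn: str) -> torch.Tensor:
    # Returns a (num_queries, k_way) distance matrix between queries x and prototypes y.
    if matching_fn == 'l2':
        return (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=2)
    if matching_fn == 'cosine':
        x_n = x / (x.norm(dim=1, keepdim=True) + 1e-8)
        y_n = y / (y.norm(dim=1, keepdim=True) + 1e-8)
        return 1 - x_n @ y_n.t()
    raise ValueError(f"Unsupported distance: {matching_fn}")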
Example #2
def update_weights(optimizer: optim.Optimizer, network: Network, batch):
    optimizer.zero_grad()

    value_loss = 0
    reward_loss = 0
    policy_loss = 0
    for image, actions, targets in batch:
        # Initial step, from the real observation.
        value, reward, policy_logits, hidden_state = network.initial_inference(
            image)
        predictions = [(1.0 / len(batch), value, reward, policy_logits)]

        # Recurrent steps, from action and previous hidden state.
        for action in actions:
            value, reward, policy_logits, hidden_state = network.recurrent_inference(
                hidden_state, action)
            # TODO: Try not scaling this for efficiency
            # Scale so the recurrent inference updates together have the same weight as the initial inference update
            predictions.append(
                (1.0 / len(actions), value, reward, policy_logits))

            hidden_state = scale_gradient(hidden_state, 0.5)

        for prediction, target in zip(predictions, targets):
            gradient_scale, value, reward, policy_logits = prediction
            target_value, target_reward, target_policy = \
                (torch.tensor(item, dtype=torch.float32, device=value.device.type) \
                for item in target)

            # Past end of the episode
            if len(target_policy) == 0:
                break

            value_loss += gradient_scale * scalar_loss(value, target_value)
            reward_loss += gradient_scale * scalar_loss(reward, target_reward)
            policy_loss += gradient_scale * cross_entropy_with_logits(
                policy_logits, target_policy)

            # print('val -------', value, target_value, scalar_loss(value, target_value))
            # print('rew -------', reward, target_reward, scalar_loss(reward, target_reward))
            # print('pol -------', policy_logits, target_policy, cross_entropy_with_logits(policy_logits, target_policy))

    value_loss /= len(batch)
    reward_loss /= len(batch)
    policy_loss /= len(batch)

    total_loss = value_loss + reward_loss + policy_loss
    scaled_loss = scale_gradient(total_loss, gradient_scale)

    logging.info('Training step {} losses'.format(network.training_steps()) + \
        ' | Total: {:.5f}'.format(total_loss) + \
        ' | Value: {:.5f}'.format(value_loss) + \
        ' | Reward: {:.5f}'.format(reward_loss) + \
        ' | Policy: {:.5f}'.format(policy_loss))

    scaled_loss.backward()
    optimizer.step()
    network.increment_step()
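scale_gradient and scalar_loss are not defined in this snippet. A common implementation of the MuZero-style gradient-scaling trick is sketched below; scalar_loss is shown as plain MSE, which is an assumption (the original may use a categorical value/reward encoding instead).

import torch
import torch.nn.functional as F


def scale_gradient(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    # Forward value is unchanged; gradients flowing back through `tensor` are scaled by `scale`.
    return tensor * scale + tensor.detach() * (1 - scale)


def scalar_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return F.mse_loss(prediction.squeeze(), target.squeeze())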
Example #3
def train(loader: DataLoader, model: torch.nn.Module, criterion, optimizer: Optimizer,
          epoch: int, noise_sd: float, device: torch.device, writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs, targets = inputs.to(device), targets.to(device)
        batch_size = inputs.size(0)

        # augment inputs with noise
        inputs = inputs + torch.randn_like(inputs, device=device) * noise_sd

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1.item(), batch_size)
        top5.update(acc5.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(
                epoch, i, len(loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))

    if writer:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

    return (losses.avg, top1.avg)
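AverageMeter is used here and in several later examples (#9, #25, #27) but never defined. A minimal version in the spirit of the PyTorch ImageNet reference script is sketched below; the optional name/format arguments follow how Example #25 constructs it, and Example #27 evidently uses a variant with a `current` attribute instead of `val`.

class AverageMeter:
    """Tracks the latest value, running sum, count and average of a metric."""

    def __init__(self, name: str = '', fmt: str = ':f'):
        self.name, self.fmt = name, fmt
        self.val = self.avg = self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count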
Example #4
def train_model(model: nn.Module,
                dataset: Dataset,
                batch_size: int,
                loss_function: Callable,
                optimizer: Optimizer,
                epochs: int = 1,
                loss_args: Union[dict, None] = None) -> Tuple[List, List]:
    """
    Train a model on the input dataset.
    Parameters
    ----------
    model: nn.Module
        The input model to be trained.
    dataset: torch.utils.data.Dataset
        The dataset to train on.
    batch_size: int
        The training batch size.
    loss_function: function with signature: (x, y, model, **kwargs) -> (loss, logits).
        The function used to compute the loss.
    optimizer: Optimizer
        The model's optimizer.
    epochs: int
        Number of epochs to train for. Default: 1.
    loss_args: dict or None
        Additional arguments to be passed to the loss function.

    Returns
    -------
    Tuple containing
        * losses: List[float]. The losses obtained at each step.
        * accuracies: List[float]. The accuracies obtained at each step.

    """
    if loss_args is None:
        loss_args = {}
    losses = []
    loss = 0
    accuracies = []
    num_train_batches = int(
        torch.ceil(torch.tensor(len(dataset) / batch_size)).item())
    for epoch in range(epochs):
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        for x, y in tqdm(iter(train_loader), total=num_train_batches):
            optimizer.zero_grad()
            loss, logits = loss_function(x, y, model, **loss_args)
            loss.backward()
            losses.append(loss.item())
            optimizer.step()

            pred = logits.argmax(dim=1, keepdim=True).view(-1)
            # print(pred.shape)
            accuracy = (pred == y).float().mean().item()
            accuracies.append(accuracy)

    return losses, accuracies
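A hedged usage sketch for train_model: the toy model, dataset and cross_entropy_loss below are placeholders invented to satisfy the (x, y, model, **kwargs) -> (loss, logits) contract described in the docstring.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset


def cross_entropy_loss(x, y, model, **kwargs):
    logits = model(x)
    return F.cross_entropy(logits, y), logits


toy_model = nn.Linear(20, 3)
toy_data = TensorDataset(torch.randn(64, 20), torch.randint(0, 3, (64,)))
sgd = torch.optim.SGD(toy_model.parameters(), lr=0.1)
losses, accuracies = train_model(toy_model, toy_data, batch_size=16,
                                 loss_function=cross_entropy_loss,
                                 optimizer=sgd, epochs=2)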
Example #5
 def load(self, path_to_checkpoint: str, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> int:
     checkpoint = torch.load(path_to_checkpoint)
     self.load_state_dict(checkpoint['state_dict'])
     step = checkpoint['step']
     if optimizer is not None:
         optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
     if scheduler is not None:
         scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
     return step
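A matching save() is implied by the keys this load() reads. The sketch below writes the same dictionary layout; it is an assumption for illustration, not the original method.

 def save(self, path_to_checkpoint: str, step: int, optimizer: Optimizer = None, scheduler: _LRScheduler = None) -> None:
     checkpoint = {'state_dict': self.state_dict(), 'step': step}
     if optimizer is not None:
         checkpoint['optimizer_state_dict'] = optimizer.state_dict()
     if scheduler is not None:
         checkpoint['scheduler_state_dict'] = scheduler.state_dict()
     torch.save(checkpoint, path_to_checkpoint)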
Example #6
            def optimizer_zero_grad(self, epoch: int, batch_idx: int,
                                    optimizer: Optimizer, optimizer_idx: int):
                if optimizer_idx == 0:
                    if batch_idx % 2 == 0:
                        optimizer.zero_grad()

                if optimizer_idx == 1:
                    if batch_idx % 5 == 0:
                        optimizer.zero_grad()
Example #7
def train_epoch(net: Net, data: SeedlingsData, epoch: int,
                normalize: transforms.Normalize, optimizer: Optimizer):
    losses = []
    train_total = 0
    train_right = 0
    for batch_index, images, labels in data.generate_train_data():
        tensor = normalize(torch.from_numpy(images))
        batch_x = Variable(tensor).cuda().float()
        batch_y = Variable(torch.from_numpy(labels)).cuda().long()

        if net.model_name == 'resnet50+':
            prob, mask, _ = remove_background(images)
            plant_area = np.sum(mask, (1, 2))
            avg_prob = np.divide(np.sum(prob * mask, (1, 2)),
                                 plant_area,
                                 out=np.zeros_like(plant_area).astype(np.float64),
                                 where=plant_area != 0)
            avg_green = np.divide(np.sum(images[:, 1, :, :] * mask, (1, 2)),
                                  plant_area,
                                  out=np.zeros_like(plant_area).astype(np.float64),
                                  where=plant_area != 0)
            plant_area = np.reshape(plant_area, (data.batch_size, 1))
            plant_area = Variable(torch.from_numpy(plant_area)).cuda().float()

            avg_prob = np.reshape(avg_prob, (data.batch_size, 1))
            avg_prob = Variable(torch.from_numpy(avg_prob)).cuda().float()

            avg_green = np.reshape(avg_green, (data.batch_size, 1))
            avg_green = Variable(torch.from_numpy(avg_green)).cuda().float()

            output = net(batch_x, plant_area, avg_prob, avg_green)
        else:
            output = net(batch_x)

        _, predict_batch_y = torch.max(output, 1)

        optimizer.zero_grad()
        criterion = nn.CrossEntropyLoss()
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        train_total += batch_y.size(0)
        train_right += sum(
            predict_batch_y.data.cpu().numpy() == batch_y.data.cpu().numpy())
        accuracy = train_right / train_total
        print("epoch:{}, batch index:{}, accuracy:{}, loss:{}".format(
            epoch, batch_index, accuracy, loss.data[0]))
        # Validate
        if batch_index != 0 and batch_index % 100 == 0:
            pass
    accuracy = train_right / train_total
    print("epoch:{}, , average accuracy:{}, average train loss:{}".format(
        epoch, accuracy,
        sum(losses) / len(losses)))
Example #8
 def _convert_to_lightning_optimizer(self, optimizer: Optimizer) -> LightningOptimizer:
     if not isinstance(optimizer, LightningOptimizer):
         optimizer = LightningOptimizer(optimizer)  # type: ignore [assignment]
     optimizer._trainer = self
     for opt_idx, opt in enumerate(self.optimizers):
         if opt == optimizer._optimizer:
             optimizer._optimizer_idx = opt_idx
             break
     return optimizer  # type: ignore [return-value]
Example #9
def train(args, loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, noise_sd: float):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    # top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # augment inputs with noise
        inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1 = accuracy(outputs, targets)[0]
        # acc1 = accuracy_sigmod(outputs, targets)
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        # top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1))

    return (losses.avg, top1.avg)
Example #10
def __stage_two(model: Module,
                optimiser: Optimizer,
                loss_fn: Callable,
                x: torch.Tensor,
                y: torch.Tensor,
                n_shot: int,
                k_way: int,
                q_queries: int,
                inner_train_steps: int,
                inner_lr: float,
                hebb_lr: float,
                train: bool,
                device: Union[str, torch.device],
                sys2net=lambda x, y: (0., (0., 0.)),
                sys2feat=lambda x: 1.):
    args = {'device': device, 'dtype': torch.double}
    meta_batch_size, mesa_batch_size = x.shape[0], x.shape[1]
    model.train(train)

    # activate model.features
    model(x.reshape(meta_batch_size * mesa_batch_size, *x.shape[2:]))

    # features of shape (meta_batch_size, n*k + q*k, num_features + 1)
    features = model.features.reshape(meta_batch_size, mesa_batch_size, -1)
    features = torch.cat(
        [features, torch.ones_like(features[:, :, :1])], dim=2)
    support_features = features[:, :n_shot * k_way]
    query_features = features[:, n_shot * k_way:]

    # make support labels of shape (meta_batch_size, n*k, k)
    y = create_nshot_task_label(k_way, n_shot).to(device)
    y = y.repeat(meta_batch_size)
    y = (torch.eye(k_way, **args)[y, :] * 2 - 1) * 10
    support_y = y.reshape(meta_batch_size, n_shot * k_way, -1)

    # make query labels of shape (meta_batch_size, q*k)
    y = create_nshot_task_label(k_way, q_queries).to(device)
    query_y = y.repeat(meta_batch_size)

    # get least distance solution on support set
    weight, bias = model.output_layer.weight, model.output_layer.bias
    v = torch.cat([weight, bias.reshape(-1, 1)], dim=1)
    v = v.unsqueeze(0).repeat(meta_batch_size, 1, 1)
    w = least_dist(support_features, support_y, v)

    # compute predictions and loss for query set
    query_y_hat = torch.bmm(query_features, w.transpose(1, 2))
    query_y_hat = query_y_hat.reshape(-1, k_way)
    loss = loss_fn(query_y_hat, query_y)
    predictions = query_y_hat.softmax(dim=1)

    if train:
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    return loss, predictions
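least_dist is not shown. One plausible reading of the "least distance solution" comment is the least-squares solution of support_features @ w.T = support_y that stays closest to the current output-layer parameters v; the batched sketch below implements that interpretation and is an assumption, not the original helper.

import torch


def least_dist(x: torch.Tensor, y: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # x: (batch, m, features), y: (batch, m, k), v: (batch, k, features).
    # Minimum-change solution: w = v + delta with x @ delta.T ~= y - x @ v.T.
    residual = y - torch.bmm(x, v.transpose(1, 2))        # (batch, m, k)
    delta_t = torch.bmm(torch.linalg.pinv(x), residual)   # (batch, features, k)
    return v + delta_t.transpose(1, 2)                    # (batch, k, features)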
Example #11
def load_optimizer(optimizer: Optimizer, path: str):
    """
    Load optimizer state for resuming training

    :param optimizer: optimizer whose state will be restored
    :param path: path to the saved optimizer state_dict
    """
    optimizer.load_state_dict(torch.load(path))
    print("Optimizer state loaded.")
Example #12
def load_checkpoint(filename: str, model: nn.Module,
                    optimizer: optim.Optimizer):
    if os.path.isfile(filename):
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

        return start_epoch, model, optimizer
Example #13
def train_epoch(train_examples,
                train_queue,
                valid_queue,
                model,
                architect,
                criterion,
                optimizer: optim.Optimizer,
                regularizer: Regularizer,
                batch_size: int,
                lr,
                verbose: bool = True):
    loss = nn.CrossEntropyLoss(reduction='mean')
    # print('avg entity embedding norm', torch.norm(model.embeddings[0].weight,dim=1).mean())
    # print('avg relation embedding norm', torch.norm(model.embeddings[1].weight,dim=1).mean())
    with tqdm.tqdm(total=train_examples.shape[0],
                   unit='ex',
                   disable=not verbose) as bar:
        bar.set_description(f'train loss')
        for step, input in enumerate(train_queue):

            model.train()

            input_var = Variable(input, requires_grad=False).cuda()
            target_var = Variable(input[:, 2],
                                  requires_grad=False).cuda()  #async=True)

            input_search = next(iter(valid_queue))
            input_search = Variable(input_search, requires_grad=False).cuda()
            target_search = Variable(input_search[:, 2],
                                     requires_grad=False).cuda()  #async=True)

            model.eval()
            architect.step(input_var,
                           target_var,
                           input_search,
                           target_search,
                           lr,
                           optimizer,
                           unrolled=args.unrolled)
            #set middle identity strength to zero to force learning convolution
            #model._arch_parameters[0].data[2,0]= -1e8
            model.train()
            predictions, factors = model.forward(input_var)
            truth = input_var[:, 2]

            l_fit = loss(predictions, truth)
            #l_reg = regularizer.forward(factors)
            l = l_fit  # + l_reg

            optimizer.zero_grad()
            l.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            optimizer.step()

            bar.update(input_var.shape[0])
            bar.set_postfix(loss=f'{l.item():.0f}')
Example #14
def gradient_step(model: Module, optimiser: Optimizer, loss_fn: Callable,
                  x: torch.Tensor, y: torch.Tensor, **kwargs):
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()

    return loss, y_pred
Example #15
    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: optim.Optimizer,
        optimizer_idx: int,
    ) -> None:
        optimizer.step()

        """
Example #16
def run_train(model: DeviceAwareModule, train_loader: DataLoader,
              optimizer: Optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(model.device), target.to(model.device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
Example #17
def train_batch(xs: Tensor, ys: Tensor, model: nn.Module, criterion: nn.Module,
                optimizer: optim.Optimizer):
    """Run one optimisation step on a single batch and return the scalar loss."""
    model.train()
    optimizer.zero_grad()
    out = model(xs)
    loss = criterion(out, ys)
    loss.backward()
    optimizer.step()
    return loss.item()
Example #18
def train(model: nn.Module, device: str, train_loader: DataLoader,
          optimizer: optim.Optimizer) -> None:
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
Example #19
def train_step(model: MNISTLearner, data: Tensor, target: Tensor,
               optimizer: Optimizer, observer: Observer):
    model.train()
    data, target = data.to(model.device), target.to(model.device)
    optimizer.zero_grad()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    observer.add_scalar('training_loss', loss.mean().item())
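The Observer argument only needs an add_scalar(tag, value) method here. Below is a minimal sketch wrapping TensorBoard's SummaryWriter; the class name and step handling are assumptions for illustration.

from torch.utils.tensorboard import SummaryWriter


class TensorboardObserver:
    def __init__(self, log_dir: str = 'runs'):
        self._writer = SummaryWriter(log_dir)
        self._step = 0

    def add_scalar(self, tag: str, value: float):
        self._writer.add_scalar(tag, value, global_step=self._step)
        self._step += 1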
Example #20
 def _single_batch_train_pass(self, X_batch: TensorType,
                              y_batch: TensorType,
                              optimizer: optim.Optimizer):
     module = self.torch_module
     module.zero_grad()
     optimizer.zero_grad()
     err = self._single_batch_test_pass(X_batch, y_batch)
     err.backward()
     optimizer.step()
     return err
Example #21
def torch_single_train(model: PyTorchForecast,
                       opt: optim.Optimizer,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       data_loader: DataLoader,
                       takes_target: bool,
                       meta_data_model: PyTorchForecast,
                       meta_data_model_representation: torch.Tensor,
                       meta_loss=None,
                       multi_targets=1,
                       forward_params: Dict = {}) -> float:
    probablistic = None
    if "probabilistic" in model.params["model_params"]:
        probablistic = True
    print('running torch_single_train')
    i = 0
    output_std = None
    running_loss = 0.0
    for src, trg in data_loader:
        opt.zero_grad()
        # Convert to CPU/GPU/TPU
        src = src.to(model.device)
        trg = trg.to(model.device)
        if meta_data_model:
            representation = meta_data_model.model.generate_representation(meta_data_model_representation)
            forward_params["meta_data"] = representation
            if meta_loss:
                output = meta_data_model.model(meta_data_model_representation)
                met_loss = compute_loss(meta_data_model_representation, output, torch.rand(2, 3, 2), meta_loss, None)
                met_loss.backward()
        if takes_target:
            forward_params["t"] = trg
        output = model.model(src, **forward_params)
        if multi_targets == 1:
            labels = trg[:, :, 0]
        elif multi_targets > 1:
            labels = trg[:, :, 0:multi_targets]
        if probablistic:
            output1 = output
            output = output.mean
            output_std = output1.stddev
        loss = compute_loss(labels, output, src, criterion, None, probablistic, output_std, m=multi_targets)
        if loss > 100:
            print("Warning: high loss detected")
        loss.backward()
        opt.step()
        if torch.isnan(loss) or loss == float('inf'):
            raise ValueError("Error infinite or NaN loss detected. Try normalizing data or performing interpolation")
        running_loss += loss.item()
        i += 1
    print("The running loss is:")
    print(running_loss)
    print("The number of items in train is: ")
    print(i)
    total_loss = running_loss / float(i)
    return total_loss
Example #22
    def project(self, latents: Latents, images: torch.Tensor, optimizer: Optimizer, num_steps: int, loss_function: Callable, lr_scheduler: _LRScheduler = None) -> Tuple[LatentPaths, Latents]:
        pbar = tqdm(range(num_steps), leave=False)
        latent_path = []
        noise_path = []

        best_latent = best_noise = best_psnr = None

        for i in pbar:
            img_gen, _ = self.generate(latents)

            batch, channel, height, width = img_gen.shape

            if height > 256:
                factor = height // 256

                img_gen = img_gen.reshape(
                    batch, channel, height // factor, factor, width // factor, factor
                )
                img_gen = img_gen.mean([3, 5])

            # # n_loss = noise_regularize(noises)
            loss, loss_dict = loss_function(img_gen, images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_dict['psnr'] = self.psnr(img_gen, images).item()
            loss_dict['lr'] = optimizer.param_groups[0]["lr"]

            if lr_scheduler is not None:
                lr_scheduler.step()

            self.log.append(loss_dict)

            if best_psnr is None or best_psnr < loss_dict['psnr']:
                best_psnr = loss_dict['psnr']
                best_latent = latents.latent.detach().clone().cpu()
                best_noise = [noise.detach().clone().cpu() for noise in latents.noise]

            if i % self.debug_step == 0:
                latent_path.append(latents.latent.detach().clone().cpu())
                noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])

            loss_description = "; ".join(f"{key}: {value:.6f}" for key, value in loss_dict.items())
            pbar.set_description(loss_description)

            loss_dict['iteration'] = i
            if self.abort_condition is not None and self.abort_condition(loss_dict):
                break

        latent_path.append(latents.latent.detach().clone().cpu())
        noise_path.append([noise.detach().clone().cpu() for noise in latents.noise])

        return LatentPaths(latent_path, noise_path), Latents(best_latent, best_noise)
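self.psnr is called above but not shown. A standard PSNR over the mean squared error is sketched below as a method of the same class; the max_value default assumes generator outputs in [-1, 1], which may not match the original code.

    def psnr(self, prediction: torch.Tensor, target: torch.Tensor, max_value: float = 2.0) -> torch.Tensor:
        mse = torch.mean((prediction - target) ** 2)
        return 10 * torch.log10(max_value ** 2 / mse)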
Example #23
def train_epoch(model: Module, optim: Optimizer, grid, img, **kwargs) -> float:
    # Unpack
    mask: Masking = kwargs.get("mask")
    pbar = kwargs.get("pbar")
    lr_scheduler = kwargs.get("lr_scheduler")
    criterion = kwargs.get("criterion", F.mse_loss)
    preconditioner = kwargs.get("preconditioner")

    # Automatic mixed precision
    context = kwargs.get("criterion", _blank_context)
    scaler = kwargs.get("scaler")

    model.train()
    optim.zero_grad()

    # Forward pass
    with context():
        pred = model(grid)
        # Any callable
        train_loss = criterion(
            pred,
            img,
        )

    if scaler:
        # Scales the loss, and calls backward()
        # to create scaled gradients
        scaler.scale(train_loss).backward()
    else:
        train_loss.backward()

    if preconditioner:
        preconditioner.step()

    if mask:
        # If mask, pass the scaler to it
        mask.step(scaler)
    else:
        if scaler:
            # Unscales gradients and calls
            # or skips optimizer.step()
            scaler.step(optim)
            # Updates the scale for next iteration
            scaler.update()
        else:
            optim.step()

    if pbar:
        # Update pbar
        pbar.update(1)

    if lr_scheduler:
        lr_scheduler.step()
    return train_loss.item()
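The mixed-precision plumbing above relies on a _blank_context fallback and, optionally, an autocast context plus GradScaler passed through kwargs. A sketch of the fallback and a typical call follows; the "context" keyword mirrors the kwargs lookup above, and everything else uses the standard torch.cuda.amp API.

import contextlib
import torch


def _blank_context():
    # No-op stand-in used when no autocast context is supplied.
    return contextlib.nullcontext()


# Typical mixed-precision invocation (sketch):
# scaler = torch.cuda.amp.GradScaler()
# train_epoch(model, optim, grid, img,
#             context=torch.cuda.amp.autocast, scaler=scaler)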
Example #24
 def _single_batch_train_pass(self, X_batch: TensorType,
                              y_batch: TensorType,
                              optimizer: optim.Optimizer):
     module = self.torch_module
     optimizer.zero_grad()
     err = self._single_batch_test_pass(X_batch, y_batch)
     err.backward()
     if self.clip_grad:
         clip_grad_norm(module.parameters(), self.clip_grad_norm)
     optimizer.step()
     return err
Example #25
def train(
    train_loader: DataLoader,
    model: nn.Module,
    criterion_cls: nn.Module,
    criterion_bound: nn.Module,
    lambda_bound_loss: float,
    optimizer: optim.Optimizer,
    epoch: int,
    device: str,
) -> float:
    losses = AverageMeter("Loss", ":.4e")

    # switch training mode
    model.train()

    for i, sample in enumerate(train_loader):
        x = sample["feature"]
        t = sample["label"]
        b = sample["boundary"]
        mask = sample["mask"]

        x = x.to(device)
        t = t.to(device)
        b = b.to(device)
        mask = mask.to(device)

        batch_size = x.shape[0]

        # compute output and loss
        output_cls, output_bound = model(x)

        loss = 0.0
        if isinstance(output_cls, list):
            n = len(output_cls)
            for out in output_cls:
                loss += criterion_cls(out, t, x) / n
        else:
            loss += criterion_cls(output_cls, t, x)

        if isinstance(output_bound, list):
            n = len(output_bound)
            for out in output_bound:
                loss += lambda_bound_loss * criterion_bound(out, b, mask) / n
        else:
            loss += lambda_bound_loss * criterion_bound(output_bound, b, mask)

        # record loss
        losses.update(loss.item(), batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return losses.avg
Example #26
    def bw_step(self, loss: torch.Tensor, optimizer: optim.Optimizer):
        if optimizer is None:
            return

        if self.is_start_cycle:
            optimizer.zero_grad()
        # Accumulate a 1/steps share of the gradient on each call within the cycle.
        (loss / self.steps).backward()
        if self.is_end_cycle:
            optimizer.step()

        self.inc_counter()
Example #27
def train(net: nn.Module, dataloader: DataLoader, optimizer: optim.Optimizer, device: torch.device, epoch, print_freq=100):
    total_batches = len(dataloader)

    data_load_time   = AverageMeter()
    batch_total_time = AverageMeter()
    losses           = AverageMeter()
    accs             = AverageMeter()

    start_time_stamp = time.time()

    net.train()
    for i, (image_ids, images, captions, cap_lens, all_captionss) in enumerate(dataloader):
        # image_ids: tuple, shape=(batch_size,)
        # images: shape=(batch_size, 3, H, W)
        # captions: word indexes, shape=(batch_size, cap_len)
        # cap_lens: true caption length, shape=(batch_size,)
        # all_captionss: tuple of list of list, shape=(batch_size, n_categories, cap_len)
        data_load_time.update(time.time() - start_time_stamp)

        # Move to GPU if available
        images   = images.to(device)
        captions = captions.to(device)
        cap_lens = cap_lens.to(device)

        # Forward prop + backward prop
        optimizer.zero_grad()
        # predictions' shape=(batch_size, max_cap_len, vocab_size)
        # alphas' shape=(batch_size, num_pixels)
        # predictions usually end with <eos>, but <bos> is excluded
        # compute loss and acc
        predictions, alphas = net(images, captions, cap_lens)
        loss, acc = net.compute_loss(predictions, alphas, captions, cap_lens)

        loss.backward()
        optimizer.step()

        # Update recorder
        batch_total_time.update(time.time() - start_time_stamp)
        losses.update(loss.item())
        accs.update(acc.item())

        start_time_stamp = time.time()

        if i % print_freq == 0:
            print('[INFO] Epoch: {0} | Batches: {1}/{2} | '
                'Data load time: {data_load_time.current:.3f} ({data_load_time.avg:.3f}) | '
                'Batch time: {batch_total_time.current:.3f} ({batch_total_time.avg:.3f}) | '
                'Loss: {losses.current:.4f} ({losses.avg:.4f}) | '
                'Acc: {accs.current:.4f} ({accs.avg:.4f})'
                .format(epoch, i, total_batches,
                    data_load_time=data_load_time,
                    batch_total_time=batch_total_time,
                    losses=losses,
                    accs=accs))
Example #28
 def optimizer_step(
     self,
     model: Union["pl.LightningModule", Module],
     optimizer: Optimizer,
     optimizer_idx: int,
     closure: Callable[[], Any],
     **kwargs: Any,
 ) -> None:
     """Hook to run the optimizer step."""
     if isinstance(model, pl.LightningModule):
         closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
     optimizer.step(closure=closure, **kwargs)
Example #29
def train_classifier(net: Net,
                     loader: DataLoader,
                     loss_fn: Callable[[Tensor, Tensor], Tensor],
                     optimizer: Optimizer,
                     device: torch.device,
                     log_every: int,
                     early_stopping: float,
                     epochs: int,
                     prefix: str = '',
                     leave: bool = True,
                     clip_gradient: float = 0.1,
                     ) -> None:
    if prefix:
        prefix += ' '           # add trailing space
    net.train()
    iterator = count() if early_stopping > 0 else range(epochs)
    for epoch in tqdm(iterator, leave=leave, desc='Training'):
        running_loss = 0.0
        correct = total = 0
        for index, data in enumerate(loader):
            inputs = data[0].to(device=device)
            labels = data[1].to(device=device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()

            # clip gradients
            if clip_gradient > 0:
                for param in net.parameters():
                    if param.grad is not None:
                        param.grad.data.clamp_(min=-clip_gradient,
                                               max=clip_gradient)

            optimizer.step()
            with torch.no_grad():
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                running_loss += loss.item()
                if index % log_every == log_every - 1:
                    logger.info('[%sEpoch %d Batch %5d] loss: %.3f',
                                prefix, epoch, index, running_loss / log_every)
                    running_loss = 0
        accuracy = correct / total
        logger.info('[Epoch %d] Training Accuracy: %.3f', epoch, accuracy)
        if early_stopping > 0:
            if accuracy > early_stopping:
                return
            if (epoch % 50 == 49) and (accuracy < 0.2):
                # not converging
                logger.warning('Network not converging. Reset Parameters.')
                net.reset_parameters()
Example #30
def resume_checkpoint(config: Config,
                      model: Module,
                      optimizer: Optimizer = None) -> int:
    """
    Resume training state from the files under config.logs generated by make_checkpoint().
    :return: number of the last epoch
    """
    last_epoch = -1
    temp_weight_path = config.temp_weight_path
    temp_optim_path = config.temp_optim_path
    if os.path.exists(config.train_record_file):
        try:
            with open(config.train_record_file, 'r') as f:
                last = f.readlines()[-1]
                import json
                info = json.loads(last)
                last_epoch = int(info["epoch"])
                last_init = str(info["init"])
                if not os.path.exists(temp_weight_path):
                    temp_weight_path = temp_weight_path.replace(
                        config.init_time, last_init)
                if not os.path.exists(temp_optim_path):
                    temp_optim_path = temp_optim_path.replace(
                        config.init_time, last_init)
            print("Continue train from last epoch %d" % last_epoch)
        except:
            warn("Rename invalid train record file from {} to {}".format(
                config.train_record_file,
                config.train_record_file + '.badfile'))
            warn("Can't get last_epoch value, {} will be returned".format(
                last_epoch))
            os.rename(config.train_record_file,
                      config.train_record_file + '.badfile')
    if os.path.exists(temp_weight_path):
        try:
            model.load_state_dict(load(temp_weight_path))
            print("Resumed weight checkpoint from {}".format(temp_weight_path))
        except:
            warn("Move invalid temp {} weights file from {} to {}".format(
                type(model), temp_weight_path, temp_weight_path + '.badfile'))
            os.rename(temp_weight_path, temp_weight_path + '.badfile')
    if optimizer is not None and os.path.exists(temp_optim_path):
        try:
            optimizer.load_state_dict(load(temp_optim_path))
            print(
                "Resumed optimizer checkpoint from {}".format(temp_optim_path))
        except:
            warn("Move invalid temp {} weights file from {} to {}".format(
                type(optimizer), temp_optim_path,
                temp_optim_path + '.badfile'))
            os.rename(temp_optim_path, temp_optim_path + '.badfile')

    return last_epoch
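A sketch of the make_checkpoint() counterpart mentioned in the docstring, writing the files and JSON record that resume_checkpoint() reads back. The Config attribute names are taken from above; the bare save(...) mirrors the bare load(...) call used in resume_checkpoint, and everything else is an assumption.

def make_checkpoint(config: Config, epoch: int, model: Module,
                    optimizer: Optimizer = None) -> None:
    import json
    save(model.state_dict(), config.temp_weight_path)
    if optimizer is not None:
        save(optimizer.state_dict(), config.temp_optim_path)
    with open(config.train_record_file, 'a') as f:
        f.write(json.dumps({"epoch": epoch, "init": config.init_time}) + "\n")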