Example #1
def train_step(
        model: FlowModel,
        config: TrainConfig,
        action: ActionFn,
        optimizer: optim.Optimizer,
        batch_size: int,
        scheduler: Any = None,
        scaler: GradScaler = None,
        pre_model: FlowModel = None,
        dkl_factor: float = 1.,
        xi: torch.Tensor = None,
):
    """Perform a single training step.

    TODO: Add `torch.device` to arguments for DDP.
    """
    t0 = time.time()
    #  layers, prior = model['layers'], model['prior']
    optimizer.zero_grad()

    loss_dkl = torch.tensor(0.0)
    if torch.cuda.is_available():
        loss_dkl = loss_dkl.cuda()

    if pre_model is not None:
        pre_xi = pre_model.prior.sample_n(batch_size)
        x = qed.ft_flow(pre_model.layers, pre_xi)
        xi = qed.ft_flow_inv(pre_model.layers, x)

    #  with torch.cuda.amp.autocast():
    x, xi, logq = apply_flow_to_prior(model.prior,
                                      model.layers,
                                      xi=xi, batch_size=batch_size)
    logp = (-1.) * action(x)
    dkl = calc_dkl(logp, logq)

    ess = calc_ess(logp, logq)
    qi = qed.batch_charges(xi)
    q = qed.batch_charges(x)
    plaq = logp / (config.beta * config.volume)
    dq = torch.sqrt((q - qi) ** 2)

    loss_dkl = dkl_factor * dkl

    if scaler is not None:
        scaler.scale(loss_dkl).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss_dkl.backward()
        optimizer.step()

    if scheduler is not None:
        scheduler.step(loss_dkl)

    metrics = {
        'dt': time.time() - t0,
        'ess': grab(ess),
        'logp': grab(logp),
        'logq': grab(logq),
        'loss_dkl': grab(loss_dkl),
        'q': grab(q),
        'dq': grab(dq),
        'plaq': grab(plaq),
    }

    return metrics
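
The scaler branch above is the standard torch.cuda.amp recipe (scale the loss, step through the scaler, then update). A minimal self-contained sketch of that pattern, with the flow model, action and qed helpers replaced by a toy regression purely for illustration, and the CPU case handled by disabling the scaler rather than branching:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast

device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = nn.Linear(8, 1).to(device)
opt = optim.Adam(net.parameters(), lr=1e-3)
scaler = GradScaler(enabled=(device == 'cuda'))  # no-op when disabled

def toy_step(x, y):
    opt.zero_grad()
    with autocast(enabled=(device == 'cuda')):
        loss = nn.functional.mse_loss(net(x), y)
    scaler.scale(loss).backward()   # scaled backward to avoid fp16 underflow
    scaler.step(opt)                # unscales gradients, then optimizer step
    scaler.update()
    return loss.item()

x, y = torch.randn(32, 8, device=device), torch.randn(32, 1, device=device)
print(toy_step(x, y))

Example #2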
def train_model(
    train_dl: data.DataLoader,
    dev_dl: data.DataLoader,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:

    device = model_utils.get_device()
    # loss_fn = nn.functional.binary_cross_entropy
    loss_fn = model_utils.l1_norm_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = torch.tensor(float('inf'))
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')
    scalar_rand = torch.distributions.uniform.Uniform(0.5, 1.5)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        # Training portion
        torch.cuda.empty_cache()
        with tqdm(total=args.train_batch_size * len(train_dl)) as progress_bar:
            model.train()
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(train_dl):
                # trump_scale = scalar_rand.sample()
                # biden_scale = scalar_rand.sample()
                # y_batch_biden = y_batch_biden * biden_scale
                # y_batch_trump = y_batch_trump * trump_scale
                # x_batch = (y_batch_trump + y_batch_biden).abs().to(device)
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred_b, y_pred_t = model(x_batch)
                if args.train_trump:
                    # loss = loss_fn(y_pred_t * x_batch, y_batch_trump)
                    loss = loss_fn(y_pred_t, y_batch_trump)
                else:
                    # loss = loss_fn(y_pred_b * x_batch, y_batch_biden)
                    loss = loss_fn(y_pred_b, y_batch_biden)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()
                if args.use_scheduler:
                    lr_scheduler.step(loss)

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar("train/Loss", loss,
                                  ((e - 1) * len(train_dl) + i) *
                                  args.train_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss

        # Validation portion
        torch.cuda.empty_cache()
        with tqdm(total=args.val_batch_size * len(dev_dl)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            for i, (x_batch, y_batch_biden, y_batch_trump,
                    _) in enumerate(dev_dl):
                x_batch = x_batch.abs().to(device)
                y_batch_biden = y_batch_biden.abs().to(device)
                y_batch_trump = y_batch_trump.abs().to(device)

                # Forward pass on model
                y_pred_b, y_pred_t = model(x_batch)
                # y_pred_b_mask = torch.ones_like(y_pred_b) * (y_pred_b > args.alpha)
                # y_pred_t_mask = torch.ones_like(y_pred_t) * (y_pred_t > args.alpha)
                y_pred_b_mask = torch.clamp(y_pred_b / x_batch, 0, 1)
                y_pred_t_mask = torch.clamp(y_pred_t / x_batch, 0, 1)

                loss_trump = val_loss_fn(y_pred_t_mask * x_batch,
                                         y_batch_trump)
                loss_biden = val_loss_fn(y_pred_b_mask * x_batch,
                                         y_batch_biden)

                if args.train_trump:
                    val_loss += loss_trump.item()
                else:
                    val_loss += loss_biden.item()
                num_batches_processed += 1

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(val_loss=val_loss /
                                         num_batches_processed)
                writer.add_scalar("Val/Biden Loss", loss_biden,
                                  ((e - 1) * len(dev_dl) + i) *
                                  args.val_batch_size)
                writer.add_scalar("Val/Trump Loss", loss_trump,
                                  ((e - 1) * len(dev_dl) + i) *
                                  args.val_batch_size)

                del x_batch
                del y_batch_biden
                del y_batch_trump
                del y_pred_b
                del y_pred_t
                del loss_trump
                del loss_biden

            # Save model if it's the best one yet.
            if val_loss / num_batches_processed < best_val_loss:
                best_val_loss = val_loss / num_batches_processed
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model saved!')
                print(f'Best validation loss yet: {best_val_loss}')
            # Save model on checkpoints.
            if e % args.checkpoint_freq == 0:
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model checkpoint reached!')
                saved_checkpoints.append(filename)
                # Delete checkpoints if there are too many
                while len(saved_checkpoints) > args.num_checkpoints:
                    os.remove(saved_checkpoints.pop(0))

    return model
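
The checkpointing logic at the end of each epoch (save on best validation loss, save periodically, rotate old files) is a common pattern. A standalone sketch of just that bookkeeping, with hard-coded validation losses and a one-checkpoint limit standing in for args; it writes small files to the working directory:

import os
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
best_val, saved, keep = float('inf'), [], 1

for epoch, val_loss in enumerate([0.9, 0.7, 0.8, 0.6], start=1):
    if val_loss < best_val:                      # best model so far
        best_val = val_loss
        torch.save(model.state_dict(), 'best_val.checkpoint')
    if epoch % 2 == 0:                           # periodic checkpoint
        fname = f'epoch_{epoch}.checkpoint'
        torch.save(model.state_dict(), fname)
        saved.append(fname)
        while len(saved) > keep:                 # drop the oldest checkpoints
            os.remove(saved.pop(0))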
Example #3
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          device,
          print_freq,
          display=False):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.to(device)
        targets = targets.to(device)

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 and display:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
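
This and several later examples rely on an AverageMeter helper that is not shown. A minimal implementation consistent with how it is used here (update(val, n), .val, .avg) might look like the following; note that Example #6 uses a variant whose constructor also takes a name and a format string:

class AverageMeter:
    """Tracks the most recent value and a running (weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)

meter = AverageMeter()
meter.update(0.5, n=32)
meter.update(0.3, n=32)
print(meter.val, meter.avg)   # 0.3 0.4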
Example #4
def update_theta(epoch: int, baseline: MovingAverageMetric, entropy_coeff,
                 grad_clip: int, data_loader, device: str,
                 master_pair: MasterPairs, architecture: NASNetwork,
                 optimizer: optim.Optimizer, writer: SummaryWriter,
                 log_frequency: int):
    start = datetime.now()
    policy_loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    normal_logp_metric = AverageMetric()
    node_normal_entropy_metric = AverageMetric()
    op_normal_entropy_metric = AverageMetric()
    reduced_logp_metric = AverageMetric()
    node_reduced_entropy_metric = AverageMetric()
    op_reduced_entropy_metric = AverageMetric()

    node_normal_entropy_coeff, node_reduced_entropy_coeff, \
        op_normal_entropy_coeff, op_reduced_entropy_coeff = [entropy_coeff, ]*4

    master_pair.unset_force_uniform()

    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)
        (normal_arch, normal_logp, node_normal_entropy, op_normal_entropy), \
            (reduced_arch, reduced_logp, node_reduced_entropy,
             op_reduced_entropy) = master_pair()
        with torch.no_grad():
            outputs = architecture(datas, normal_arch, reduced_arch)
        accuracy_metric.update(targets, outputs)
        accuracy_1 = accuracy_metric.last_accuracy(1).rate
        baseline.update(accuracy_1)
        reward = accuracy_1 - baseline.value
        policy_loss = -(normal_logp + reduced_logp) * reward \
            - (node_normal_entropy*node_normal_entropy_coeff
               + op_normal_entropy*op_normal_entropy_coeff
                + node_reduced_entropy*node_reduced_entropy_coeff
                + op_reduced_entropy*op_reduced_entropy_coeff)

        optimizer.zero_grad()
        policy_loss.backward()
        if grad_clip is not None:
            nn.utils.clip_grad_norm_(master_pair.parameters(), grad_clip)
        optimizer.step()

        # update metrics
        policy_loss_metric.update(policy_loss)
        normal_logp_metric.update(normal_logp)
        node_normal_entropy_metric.update(node_normal_entropy)
        op_normal_entropy_metric.update(op_normal_entropy)
        reduced_logp_metric.update(reduced_logp)
        node_reduced_entropy_metric.update(node_reduced_entropy)
        op_reduced_entropy_metric.update(op_reduced_entropy)

        # iteration log
        if iter_ % log_frequency == 0 or iter_ == len(data_loader):
            message = f"UPDATE THETA, epoch={epoch:03d}, Iter={iter_}/{len(data_loader)}, "
            message += f"reward={reward:.4f}, "
            message += f"pocily loss={policy_loss_metric.last:.4f}({policy_loss_metric.value:.4f}), "
            message += f"moving accuracy={baseline.value*100:.2f}%, "
            message += f"normal_logp={normal_logp_metric.last:.4f}({normal_logp_metric.value:.4f}), "
            message += f"node_normal_entropy={node_normal_entropy_metric.last:.4f}({node_normal_entropy_metric.value:.4f}), "
            message += f"op_normal_entropy={op_normal_entropy_metric.last:.4f}({op_normal_entropy_metric.value:.4f}), "
            message += f"reduced_logp={reduced_logp_metric.last:.4f}({reduced_logp_metric.value:.4f}), "
            message += f"node_reduced_entropy={node_reduced_entropy_metric.last:.4f}({node_reduced_entropy_metric.value:.4f}), "
            message += f"op_reduced_entropy={op_reduced_entropy_metric.last:.4f}({op_reduced_entropy_metric.value:.4f})."
            if iter_ == len(data_loader):
                message += f" Eplased time={datetime.now()-start}."
            utils.logger.info(message)

    writer.add_scalar("update_theta/policy_loss", policy_loss_metric.value,
                      epoch)
    writer.add_scalar("update_theta/baseline", baseline.value, epoch)
    writer.add_scalar("update_theta/accuracy@1",
                      accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar("update_theta/accuracy@5",
                      accuracy_metric.accuracy(5).rate, epoch)
    writer.add_scalar("update_theta/normal_logp", normal_logp_metric.value,
                      epoch)
    writer.add_scalar("update_theta/node_normal_entropy",
                      node_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/op_normal_entropy",
                      op_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/reduced_logp", reduced_logp_metric.value,
                      epoch)
    writer.add_scalar("update_theta/node_reduced_entropy",
                      node_reduced_entropy_metric.value, epoch)
    writer.add_scalar("update_theta/op_reduced_entropy",
                      op_reduced_entropy_metric.value, epoch)
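
The policy loss above is REINFORCE with a moving-average baseline and an entropy bonus. A compact sketch of the same objective on a toy categorical policy (the reward, learning rate and coefficient values are arbitrary):

import torch
import torch.optim as optim
from torch.distributions import Categorical

logits = torch.zeros(5, requires_grad=True)            # policy parameters
opt = optim.Adam([logits], lr=1e-2)
baseline, momentum, entropy_coeff = 0.0, 0.9, 1e-2

for _ in range(100):
    dist = Categorical(logits=logits)
    action = dist.sample()
    reward = 1.0 if action.item() == 3 else 0.0        # toy reward signal
    baseline = momentum * baseline + (1 - momentum) * reward
    advantage = reward - baseline                      # reward minus baseline
    loss = -dist.log_prob(action) * advantage - entropy_coeff * dist.entropy()
    opt.zero_grad()
    loss.backward()
    opt.step()

Example #5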
def train(
    model: nn.Module,
    num_epochs: int,
    dataloader: DataLoader,
    optimizer: Optimizer,
    lr_scheduler: Optional[LRScheduler] = None,
    num_gradient_accumulation_steps: Optional[int] = 1,
    max_gradient_norm: Optional[float] = None,
    device: Optional[torch.device] = torch.device('cpu'),
    local_rank: Optional[int] = 0,
    use_distributed: Optional[bool] = False,
    is_master: Optional[bool] = True,
    use_tqdm: Optional[bool] = True,
    logger: Optional[Logger] = None,
) -> None:
    # put model in train mode
    model.train()

    # keep track of the last loss
    last_loss = 0

    for epoch in range(num_epochs):
        # synchronize all processes
        if use_distributed:
            dist.barrier()

        if is_master and logger is not None:
            logger.info(f'Starting with epoch {epoch+1}/{num_epochs}')

        # initialize the progress bar
        if is_master and use_tqdm:
            pbar = tqdm(
                desc=f'Training [epoch {epoch+1}/{num_epochs}]',
                total=len(dataloader),
                unit='batch',
            )

        for step, batch in enumerate(dataloader):
            # unpack batch
            sequences, attention_masks, _, start_positions, end_positions, _, _, _ = batch

            # send sequences, attention_masks, start_positions and end_positions to device
            sequences = sequences.to(device)
            attention_masks = attention_masks.to(device)
            start_positions = start_positions.to(device)
            end_positions = end_positions.to(device)

            # forward pass (loss computation included)
            outputs = model(input_ids=sequences,
                            attention_mask=attention_masks,
                            start_positions=start_positions,
                            end_positions=end_positions)
            loss = outputs[0]
            last_loss = loss.item()

            if use_distributed:
                loss = loss.mean()

            # rescale the loss
            loss /= num_gradient_accumulation_steps

            # backward pass
            loss.backward()

            if step % num_gradient_accumulation_steps == 0:
                # clip the gradient
                if max_gradient_norm is not None:
                    clip_grad_norm_(model.parameters(), max_gradient_norm)

                # update the parameters
                optimizer.step()
                if lr_scheduler is not None:
                    lr_scheduler.step()

                # clear all gradients
                optimizer.zero_grad()

            # update the progress bar
            if is_master and use_tqdm:
                pbar.update()
                pbar.set_postfix({'last_loss': last_loss})

        # close the progress bar
        if is_master and use_tqdm:
            pbar.close()
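
The num_gradient_accumulation_steps handling above rescales each micro-batch loss and only steps the optimizer every N batches. A stripped-down sketch of that schedule (1-indexed here, with an explicit flush after the last micro-batch):

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(10, 1)
opt = optim.SGD(net.parameters(), lr=0.1)
accum = 4

batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(10)]
opt.zero_grad()
for step, (x, y) in enumerate(batches, start=1):
    loss = nn.functional.mse_loss(net(x), y) / accum   # rescale so gradients average
    loss.backward()                                    # gradients accumulate across calls
    if step % accum == 0 or step == len(batches):      # step every `accum` batches, plus a flush
        nn.utils.clip_grad_norm_(net.parameters(), 1.0)
        opt.step()
        opt.zero_grad()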
Example #6
    def adv_train(self, epoch: int, optimizer: optim.Optimizer, lr_scheduler: optim.lr_scheduler._LRScheduler = None,
                  validate_interval=10, save=False, verbose=True, indent=0,
                  **kwargs):
        loader_train = self.dataset.loader['train']
        file_path = os.path.join(self.folder_path, self.get_filename() + '.pth')

        _, best_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)

        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        params: list[nn.Parameter] = []
        for param_group in optimizer.param_groups:
            params.extend(param_group['params'])
        for _epoch in range(epoch):
            losses.reset()
            top1.reset()
            top5.reset()
            epoch_start = time.perf_counter()
            if verbose and env['tqdm']:
                loader_train = tqdm(loader_train)
            self.model.activate_params(params)
            optimizer.zero_grad()
            for data in loader_train:
                _input, _label = self.model.get_data(data)
                noise = torch.zeros_like(_input)

                def loss_fn(X: torch.FloatTensor):
                    return -self.model.loss(X, _label)
                adv_x = _input
                self.model.train()
                loss = self.model.loss(adv_x, _label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                for m in range(self.pgd.iteration):
                    self.model.eval()
                    adv_x, _ = self.pgd.optimize(_input=_input, noise=noise, loss_fn=loss_fn, iteration=1)
                    optimizer.zero_grad()
                    self.model.train()
                    loss = self.model.loss(adv_x, _label)
                    loss.backward()
                    optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    _output = self.model.get_logits(_input)
                acc1, acc5 = self.model.accuracy(_output, _label, topk=(1, 5))
                batch_size = int(_label.size(0))
                losses.update(loss.item(), batch_size)
                top1.update(acc1, batch_size)
                top5.update(acc5, batch_size)
            epoch_time = str(datetime.timedelta(seconds=int(
                time.perf_counter() - epoch_start)))
            self.model.eval()
            self.model.activate_params([])
            if verbose:
                pre_str = '{blue_light}Epoch: {0}{reset}'.format(
                    output_iter(_epoch + 1, epoch), **ansi).ljust(64 if env['color'] else 35)
                _str = ' '.join([
                    f'Loss: {losses.avg:.4f},'.ljust(20),
                    f'Top1 Clean Acc: {top1.avg:.3f}, '.ljust(30),
                    f'Top5 Clean Acc: {top5.avg:.3f},'.ljust(30),
                    f'Time: {epoch_time},'.ljust(20),
                ])
                prints(pre_str, _str, prefix='{upline}{clear_line}'.format(**ansi) if env['tqdm'] else '',
                       indent=indent)
            if lr_scheduler:
                lr_scheduler.step()

            if validate_interval != 0:
                if (_epoch + 1) % validate_interval == 0 or _epoch == epoch - 1:
                    _, cur_acc = self.validate_fn(verbose=verbose, indent=indent, **kwargs)
                    if cur_acc > best_acc:
                        prints('best result update!', indent=indent)
                        prints(f'Current Acc: {cur_acc:.3f}    Previous Best Acc: {best_acc:.3f}', indent=indent)
                        best_acc = cur_acc
                    if save:
                        self.model.save(file_path=file_path, verbose=verbose)
                    if verbose:
                        print('-' * 50)
        self.model.zero_grad()
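
The loop above interleaves PGD perturbation of the inputs with parameter updates on the perturbed batch. self.pgd.optimize is project-specific; for reference, a generic L-infinity PGD attack step might look like this (the epsilon, step size and toy network are illustrative assumptions):

import torch
import torch.nn as nn
import torch.nn.functional as F

def pgd_attack(model, x, y, eps=8 / 255, alpha=2 / 255, steps=7):
    """Return an adversarial example inside an L-inf ball of radius eps."""
    x_adv = x.clone().detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = F.cross_entropy(model(x_adv), y)
        grad, = torch.autograd.grad(loss, x_adv)
        with torch.no_grad():
            x_adv = x_adv + alpha * grad.sign()            # ascend the loss
            x_adv = x + (x_adv - x).clamp(-eps, eps)       # project back into the ball
            x_adv = x_adv.clamp(0, 1)                      # keep a valid pixel range
    return x_adv.detach()

net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
x, y = torch.rand(4, 3, 8, 8), torch.randint(0, 10, (4,))
x_adv = pgd_attack(net, x, y)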
Example #7
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, args):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    m = Bernoulli(torch.tensor([args.calibrated_alpha]).cuda())

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # make MNIST binary
        if args.dataset == 'mnist':
            inputs = (inputs > 0.5).type(torch.cuda.FloatTensor)

        # augment inputs with noise
        if args.perturb == 'bernoulli':
            mask = m.sample(inputs.shape).squeeze(-1)
            # make sure that the value is normalized
            rand_inputs = torch.randint_like(
                inputs, low=0, high=args.K + 1, device='cuda') / float(args.K)
            inputs = inputs * mask + rand_inputs * (1 - mask)
        elif args.perturb == 'gaussian':
            inputs = inputs + torch.randn_like(inputs,
                                               device='cuda') * args.sigma

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if (i + 1) % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i + 1,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
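
The 'bernoulli' and 'gaussian' branches above are input-smoothing augmentations. A self-contained sketch of both corruptions on a dummy batch, assuming the inputs already lie in [0, 1]:

import torch
from torch.distributions import Bernoulli

inputs = torch.rand(4, 1, 8, 8)            # dummy batch in [0, 1]
K, alpha, sigma = 255, 0.8, 0.25

# Bernoulli corruption: keep each pixel with probability alpha,
# otherwise replace it with a uniformly drawn quantized value.
keep = Bernoulli(torch.tensor(alpha)).sample(inputs.shape)
rand_inputs = torch.randint_like(inputs, low=0, high=K + 1) / float(K)
bernoulli_inputs = inputs * keep + rand_inputs * (1 - keep)

# Gaussian corruption: additive isotropic noise with standard deviation sigma.
gaussian_inputs = inputs + torch.randn_like(inputs) * sigma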
Example #8
def train(model: nn.Module,
          optimizer: optim.Optimizer,
          train_data: DataLoader,
          use_cuda: bool = True,
          scheduler=None,
          bce_weight: float = 1,
          mse_weight: float = 0.1,
          misclass_weight: float = 1,
          corclass_weight: float = 1,
          threshold: float = 0.7,
          gci_threshold: float = 0.5):
    model.train()
    loss_sum = 0
    bce_loss = 0
    mse_loss = 0
    gci_misclass = 0
    misses = 0
    bce_weight = Variable(th.Tensor([bce_weight]))
    mse_weight = Variable(th.Tensor([mse_weight]))
    misclass_weight = Variable(th.Tensor([misclass_weight]))
    corclass_weight = Variable(th.Tensor([corclass_weight]))
    thresh = Variable(th.Tensor([threshold]))
    gci_thresh = Variable(th.Tensor([gci_threshold]))
    batches = len(train_data)

    if use_cuda:
        if th.cuda.is_available():
            model.cuda()
        else:
            print('Warning: GPU not available, Running on CPU')
    for data, target in train_data:
        if scheduler is not None:
            scheduler.step()

        if use_cuda:
            data, target = data.cuda(), target.cuda()
            bce_weight = bce_weight.cuda()
            mse_weight = mse_weight.cuda()
            misclass_weight = misclass_weight.cuda()
            corclass_weight = corclass_weight.cuda()
            thresh = thresh.cuda()
            gci_thresh = gci_thresh.cuda()

        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        # print(len(data))

        peak_distance_target = target[:, 0]
        peak_indicator_target = target[:, 1]
        output = model(data)

        distance = (output[:, 1])
        probabilities = output[:, 0]

        loss_bce = F.binary_cross_entropy_with_logits(probabilities,
                                                      peak_indicator_target)
        # print(loss_bce, loss_bce.mean())
        loss_mse = (distance * peak_indicator_target - peak_distance_target * peak_indicator_target) ** 2
        loss_mse = loss_mse.sum()/peak_indicator_target.sum()
        out = (F.sigmoid(probabilities) > gci_thresh).float()
        loss_misclass = (1 - peak_indicator_target) * (
            F.sigmoid(probabilities)**2)
        # loss_misclass = (1 - peak_indicator_target) * (out)
        loss_misclass = loss_misclass.mean()

        misses_temp = (1 - peak_indicator_target) * out
        misses += misses_temp.mean().item()

        out = (F.sigmoid(probabilities) > gci_thresh).float()
        gci_misclass_temp = peak_indicator_target * (1 - out)
        gci_misclass += gci_misclass_temp.mean().item()

        loss_corrclass = peak_indicator_target * ((
            1 - F.sigmoid(probabilities))**2)
        loss_corrclass = loss_corrclass.mean()

        net_loss = bce_weight * loss_bce + mse_weight * loss_mse

        loss_sum += net_loss.item()
        bce_loss += loss_bce.item()
        mse_loss += loss_mse.item()

        net_loss.backward()
        # TODO: Gradient Clipping
        optimizer.step()
    return loss_sum / batches, bce_loss / batches, mse_loss / batches, gci_misclass / batches, misses / batches
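
The objective above mixes a BCE term over peak indicators with an MSE term restricted to frames that actually contain a peak. A small sketch of that masked combination on random tensors (the weights mirror the defaults bce_weight=1, mse_weight=0.1):

import torch
import torch.nn.functional as F

logits = torch.randn(16)                      # predicted peak logits
pred_dist = torch.randn(16)                   # predicted peak distances
indicator = (torch.rand(16) > 0.5).float()    # 1 where a peak is present
target_dist = torch.rand(16)

bce = F.binary_cross_entropy_with_logits(logits, indicator)
mse = ((pred_dist - target_dist) * indicator).pow(2).sum() / indicator.sum().clamp(min=1)
loss = 1.0 * bce + 0.1 * mse

Example #9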
def train(model_G: nn.Module,
          model_D: nn.Module,
          optimizer_G: optim.Optimizer,
          optimizer_D: optim.Optimizer,
          train_data: DataLoader,
          use_cuda: bool = True):
    model_G.train()
    model_D.train()
    loss_sum = 0
    loss_D = 0
    loss_G = 0
    D_real_prob = 0
    D_fake_prob = 0
    batches = len(train_data)

    if use_cuda:
        if th.cuda.is_available():
            model_G.cuda()
            model_D.cuda()
        else:
            print('Warning: GPU not available, Running on CPU')

    for x_train, y_train in train_data:
        if use_cuda:
            y_train = y_train.type(th.LongTensor).cuda()
            x_train, y_train = x_train.cuda(), y_train.cuda()

        batch_size = x_train.shape[0]

        x_train, y_train = Variable(x_train), Variable(y_train)
        optimizer_G.zero_grad()
        optimizer_D.zero_grad()

        # Training the DISCRIMINATOR
        z = Variable(th.randn(batch_size, 1)).cuda()
        ones_label = Variable(th.ones(batch_size, 1)).cuda()
        zeros_label = Variable(th.zeros(batch_size, 1)).cuda()
        image = model_G(z)

        D_real = model_D(x_train)
        D_fake = model_D(image)

        D_loss_real = F.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = F.binary_cross_entropy(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake

        D_real_prob += D_real.mean().item()
        D_fake_prob += D_fake.mean().item()

        D_loss.backward()
        optimizer_D.step()
        optimizer_D.zero_grad()
        optimizer_G.zero_grad()

        # Training the GENERATOR
        for i in range(10):
            z = Variable(th.randn(batch_size, 1)).cuda()
            image = model_G(z)
            D_fake = model_D(image)
            G_loss = F.binary_cross_entropy(D_fake, ones_label)

            G_loss.backward()
            optimizer_G.step()
            optimizer_D.zero_grad()
            optimizer_G.zero_grad()

        loss_D += D_loss.item()
        loss_G += G_loss.item()
        loss_sum += loss_D + loss_G
    th.cuda.empty_cache()

    return loss_sum / batches, loss_D / batches, loss_G / batches, D_real_prob / batches, D_fake_prob / batches
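
The alternating discriminator/generator updates above follow the standard BCE GAN recipe. A compact CPU-only sketch of one such iteration on toy 1-D data (architectures and learning rates are illustrative):

import torch
import torch.nn as nn
import torch.optim as optim

G = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
D = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1), nn.Sigmoid())
opt_G = optim.Adam(G.parameters(), lr=2e-4)
opt_D = optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()

real = torch.randn(64, 1) * 0.5 + 2.0          # toy "real" samples
ones, zeros = torch.ones(64, 1), torch.zeros(64, 1)

# Discriminator step: push D(real) toward 1 and D(fake) toward 0.
fake = G(torch.randn(64, 1)).detach()
d_loss = bce(D(real), ones) + bce(D(fake), zeros)
opt_D.zero_grad()
d_loss.backward()
opt_D.step()

# Generator step: fool the discriminator into predicting 1 for fakes.
fake = G(torch.randn(64, 1))
g_loss = bce(D(fake), ones)
opt_G.zero_grad()
g_loss.backward()
opt_G.step()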
Example #10
def train(epoch: int, model: nn.Module, loader: data.DataLoader,
          criterion: nn.modules.loss._Loss, optimizer: optim.Optimizer,
          scheduler: optim.lr_scheduler._LRScheduler, only_epoch_sche: bool,
          use_amp: bool, accumulated_steps: int, device: str,
          log_interval: int):
    model.train()

    scaler = GradScaler() if use_amp else None

    gradient_accumulator = GradientAccumulator(accumulated_steps)

    loss_metric = AverageMetric("loss")
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    ETA = EstimatedTimeArrival(len(loader))
    speed_tester = SpeedTester()

    lr = optimizer.param_groups[0]['lr']
    _logger.info(f"Train start, epoch={epoch:04d}, lr={lr:.6f}")

    for time_cost, iter_, (inputs, targets) in time_enumerate(loader, start=1):
        inputs, targets = inputs.to(device=device), targets.to(device=device)

        optimizer.zero_grad()

        with autocast(enabled=use_amp):
            outputs = model(inputs)
            loss: torch.Tensor = criterion(outputs, targets)

        gradient_accumulator.backward_step(model, loss, optimizer, scaler)

        if scheduler is not None:
            if only_epoch_sche:
                if iter_ == 1:
                    scheduler.step()
            else:
                scheduler.step()

        loss_metric.update(loss)
        accuracy_metric.update(outputs, targets)
        ETA.step()
        speed_tester.update(inputs)

        if iter_ % log_interval == 0 or iter_ == len(loader):
            _logger.info(", ".join([
                "TRAIN",
                f"epoch={epoch:04d}",
                f"iter={iter_:05d}/{len(loader):05d}",
                f"fetch data time cost={time_cost*1000:.2f}ms",
                f"fps={speed_tester.compute()*world_size():.0f} images/s",
                f"{loss_metric}",
                f"{accuracy_metric}",
                f"{ETA}",
            ]))
            speed_tester.reset()

    return {
        "lr": lr,
        "train/loss": loss_metric.compute(),
        "train/top1_acc": accuracy_metric.at(1).rate,
        "train/top5_acc": accuracy_metric.at(5).rate,
    }
Example #11
def train(args, train_data_loader: DataLoader, valid_data_loader: DataLoader,
          model: Net, criterion, optimizer: optim.Optimizer, device):
    # save model
    if args.save_model:
        if not os.path.exists(args.save_directory):
            os.makedirs(args.save_directory)
    epochs = args.epochs

    train_losses = []
    valid_losses = []
    for epoch_id in range(epochs):
        # train_loss = 0.0
        # valid_loss = 0.0
        ######################
        # training the model #
        ######################
        model.train()
        train_batch_cnt = 0
        train_mean_pts_loss = 0.0
        for batch_idx, batch in enumerate(train_data_loader):
            train_batch_cnt += 1
            img = batch['image']
            landmark = batch['landmarks']

            # ground truth
            input_img = img.to(device)
            target_pts = landmark.to(device)

            # clear the gradients of all optimized variables(torch.Tensor)
            optimizer.zero_grad()
            output_pts = model(input_img)
            loss = criterion(output_pts, target_pts)

            # do BP automatically
            loss.backward()
            optimizer.step()  # update the optimizer's parameters
            train_mean_pts_loss += loss.item()

            # show log info
            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\t pts_loss: {:.6f}'.
                      format(epoch_id, batch_idx * len(img),
                             len(train_data_loader.dataset),
                             100. * batch_idx / len(train_data_loader),
                             loss.item()))
        train_mean_pts_loss /= train_batch_cnt
        train_losses.append(train_mean_pts_loss)

        #######################
        # validate the model #
        #######################
        valid_mean_pts_loss = 0.0
        model.eval()  #prepare model for evaluation
        with torch.no_grad():
            valid_batch_cnt = 0
            for valid_batch_idx, batch in enumerate(valid_data_loader):
                valid_batch_cnt += 1
                valid_img = batch['image']
                landmark = batch['landmarks']

                input_img = valid_img.to(device)
                target_pts = landmark.to(device)

                output_pts = model(input_img)
                valid_loss = criterion(output_pts, target_pts)
                valid_mean_pts_loss += valid_loss.item()
            valid_mean_pts_loss /= valid_batch_cnt * 1.0
            valid_losses.append(valid_mean_pts_loss)

            print('Valid: pts_loss: {:.6f}'.format(valid_mean_pts_loss))
        print('====================================================')
        if args.save_model:
            saved_model_name = os.path.join(
                args.save_directory,
                # f'detector_epoch_{epoch_id}_{train_mean_pts_loss}_{valid_mean_pts_loss}.pt')
                f'detector_epoch_{args.phase}_{epoch_id}.pt')
            torch.save(model.state_dict(), saved_model_name)
        draw_loss(train_losses, valid_losses, args.phase)

    return train_losses, valid_losses
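
The validation pass above (eval mode, torch.no_grad(), batch-averaged loss) is the standard pattern; condensed into a small reusable helper:

import torch
import torch.nn as nn

def evaluate(model, loader, criterion, device='cpu'):
    """Return the mean per-batch loss over a validation loader."""
    model.eval()
    total, batches = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            total += criterion(model(x), y).item()
            batches += 1
    model.train()
    return total / max(batches, 1)

batches = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
print(evaluate(nn.Linear(4, 1), batches, nn.MSELoss()))

Example #12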
    def train_epoch(self,
                    dataset: TLGDataset,
                    batch_size: int,
                    criterion: Callable[[FloatTensor, LongTensor],
                                        FloatTensor],
                    optimizer: optim.Optimizer,
                    train_indices: List[int],
                    n_print=100,
                    epoch=0) -> Tuple[float, int, int, int, int]:
        self.train()

        permutation = np.random.permutation(train_indices)

        batch_start = 0
        loss = 0.
        BS, BTS, BW, BTW = 0, 0, 0, 0

        running_batch_time = 0.0

        # while batch_start < len(permutation):
        for i in range(len(permutation)):
            start_time = time.time()
            optimizer.zero_grad()
            # batch_end = min([batch_start + batch_size, len(permutation)])

            # batch_x = [dataset.X[permutation[i]] for i in range(batch_start, batch_end)]
            # batch_y = [dataset.Y[permutation[i]] for i in range(batch_start, batch_end)]
            batch_x = dataset.X[permutation[i]]
            batch_y = dataset.Y[permutation[i]]

            # lens = list(map(len, batch_x))
            lens = torch.sum((batch_x.word != dataset.x_pad_token).long(),
                             dim=1).to(self.device)

            # batch_x = pad_sequence(batch_x, batch_first=True).to(self.device)
            # batch_y = pad_sequence(batch_y, batch_first=True).long().to(self.device)
            batch_e = F.embedding(batch_y.to(self.device),
                                  self.transformer.embedding_matrix)

            encoder_mask = torch.ones(batch_y.shape[0], batch_y.shape[1],
                                      batch_x.shape[1])
            for j, l in enumerate(lens):  # separate index so the outer batch counter `i` is not shadowed
                encoder_mask[j, :, l::] = 0
            encoder_mask = encoder_mask.to(self.device)
            decoder_mask = Mask(
                (batch_x.shape[0], batch_y.shape[1], batch_y.shape[1])).to(
                    self.device)  # does this have to be t()?

            batch_p = self.forward(batch_x, batch_e, encoder_mask,
                                   decoder_mask)

            batch_loss = criterion(batch_p[:, :-1].permute(
                0, 2, 1), batch_y[:, 1:].to(self.device)) / lens.float().sum()
            loss += batch_loss.item()
            batch_loss.backward()
            optimizer.step()
            argmaxes = batch_p.argmax(dim=-1)
            # print('pre argmaxes', argmaxes.size(), argmaxes[0])
            # print('pre y', batch_y.size(), batch_y[0])
            argmaxes = argmaxes[:, :-1]
            y = batch_y[:, 1:]
            # print('post argmaxes', argmaxes.size(), argmaxes[0])
            # print('post y', y.size(), y[0])

            (bs, bts), (bw, btw) = accuracy(argmaxes, y.to(self.device),
                                            dataset.type_dict[PAD])
            # (bs, bts), (bw, btw) = accuracy(batch_p[:, :-1].argmax(dim=-1), batch_y[:, 1:], dataset.type_dict[PAD])
            BS += bs
            BTS += bts
            BW += bw
            BTW += btw

            running_batch_time += time.time() - start_time

            if i % n_print == n_print - 1:  # print every n mini-batches

                batch_time = running_batch_time / n_print
                print('[%d, %5d] loss: %.3f | acc: %.3f | %.1f %s | %.1f %s' %
                      (epoch + 1, i + 1, loss / n_print, BTW / BW,
                       batch_time if batch_time >= 1 else 1 / batch_time,
                       's/batch' if batch_time >= 1 else 'batch(es)/s',
                       batch_time / batch_size if batch_time / batch_size >= 1
                       else batch_size / batch_time, 's/expl'
                       if batch_time / batch_size >= 1 else 'expl(s)/s'),
                      file=sys.stderr)
                # if str(device).startswith('cuda'):
                #     print(torch.cuda.memory_summary(abbreviated=False), file=sys.stderr)

                # assist.info['batch'] = train_i + 1
                # assist.info['batch_loss'] = running_loss / n_print
                # assist.info['batch_acc'] = running_acc / n_print
                # assist.info['ex_per_s'] = batch_size / batch_time
                # assist.step()

                # running_loss = 0.0
                # running_acc = 0.0
                running_batch_time = 0.0

            # batch_start += batch_size

        return loss, BS, BTS, BW, BTW
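
The mask construction above builds a padding mask for the encoder from sequence lengths and a causal mask for the decoder (the Mask helper is assumed to produce something like the lower-triangular mask below). A self-contained sketch for a batch of two sequences with source lengths 3 and 5 and a target length of 4:

import torch

lens = torch.tensor([3, 5])       # true source lengths
src_len, tgt_len = 5, 4
batch = lens.shape[0]

# Encoder padding mask: (batch, tgt_len, src_len), zero where the source is padding.
positions = torch.arange(src_len).unsqueeze(0)            # (1, src_len)
pad_mask = (positions < lens.unsqueeze(1)).float()        # (batch, src_len)
encoder_mask = pad_mask.unsqueeze(1).expand(batch, tgt_len, src_len)

# Decoder causal mask: (batch, tgt_len, tgt_len), lower-triangular.
decoder_mask = torch.tril(torch.ones(tgt_len, tgt_len)).unsqueeze(0).expand(batch, -1, -1)

Example #13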
def train(model: nn.Module, optimizer: optim.Optimizer, dataloader: DataLoader, epochs: int,
          loss_criterion: str, model_dir: str, plateau_limit: int, apply_nested_dropout: bool,
          reconstruct: bool, **kwargs):
    print(f'The model has {utils.get_num_parameters(model):,} parameters')
    testloader = kwargs.pop('testloader', None)
    lr_scheduler = kwargs.pop('lr_scheduler', None)

    loss_function = getattr(nn, loss_criterion)()
    batch_print = len(dataloader) // 5

    model.train()
    device = utils.get_device()
    model.to(device)  # TODO check if this actually does anything

    losses = []
    accuracies = []
    best_loss = float('inf')
    best_accuracy = 0
    plateau = 0
    train_time = 0
    for epoch in range(epochs):
        epoch_start = time.time()
        line = f'\tEpoch {epoch + 1}/{epochs}'
        if apply_nested_dropout and epoch > 0:
            line += f' ({model.get_converged_unit()}/{model.get_dropout_dim()} converged units)'
        print(line)

        batch_losses = []
        for i, (X, y) in enumerate(dataloader):
            optimizer.zero_grad()
            X = X.to(device)
            y = y.to(device)
            prediction = model(X)

            if reconstruct:
                loss = loss_function(prediction, X)
            else:
                loss = loss_function(prediction, y)

            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())
            if (i + 1) % batch_print == 0:
                batch_loss = utils.format_number(np.average(batch_losses[-batch_print:]))
                print(f'Batch {i + 1} loss: {batch_loss}')

            if apply_nested_dropout:
                model(X)
                if model.has_converged():
                    break

        epoch_loss = utils.format_number(np.average(batch_losses))
        losses.append(epoch_loss)

        epoch_time = time.time() - epoch_start
        train_time += epoch_time

        print(f'\tEpoch loss {epoch_loss}')

        model_save_kwargs = dict(**kwargs, epoch=epoch, train_time=utils.format_time(train_time), losses=losses)
        has_improved = False
        if testloader is not None:
            model.eval()
            eval_accuracy = round(utils.get_model_accuracy(model, testloader, device), 3)
            model.train()
            accuracies.append(eval_accuracy)
            print(f'\tEvaluation accuracy {eval_accuracy}')

            if eval_accuracy > best_accuracy:
                best_accuracy = eval_accuracy
                has_improved = True
                model_save_kwargs.update(accuracies=accuracies, best_accuracy=best_accuracy)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_accuracy)

        elif epoch_loss < best_loss:
            best_loss = epoch_loss
            has_improved = True
            model_save_kwargs.update(best_loss=best_loss)

        print(f'\tEpoch time {utils.format_time(epoch_time)}\n')
        if has_improved:
            utils.save_model(model, optimizer, f'{model_dir}/model', **model_save_kwargs)
            plateau = 0
        else:
            plateau += 1

        if (plateau == plateau_limit) or (apply_nested_dropout is True and model.has_converged()):
            break

    if apply_nested_dropout is True and model.has_converged():
        end = 'nested dropout has converged'
        print('Nested dropout has converged!')
    elif plateau == plateau_limit:
        end = 'has plateaued'
        print('The model has plateaued...')
    else:
        end = f'reached max number of epochs ({epochs})'
        print('The maximum number of epochs has been reached...')
    utils.update_save(f'{model_dir}/model', end=end)

    return losses
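
The plateau/plateau_limit bookkeeping above is a simple early-stopping counter; reduced to its core with hard-coded validation losses:

best_metric = float('inf')
plateau, plateau_limit = 0, 3

for epoch_metric in [0.9, 0.8, 0.85, 0.84, 0.83, 0.82]:   # e.g. validation losses
    if epoch_metric < best_metric:
        best_metric = epoch_metric
        plateau = 0                # improvement: reset the counter
    else:
        plateau += 1               # no improvement this epoch
    if plateau >= plateau_limit:
        print(f'stopping: no improvement for {plateau_limit} epochs')
        break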
Example #14
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, noise_sd: float):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # augment inputs with noise
        inputs = inputs + randgn_like(inputs, p=args.p,
                                      device='cuda') * noise_sd
        if (args.scale_down != 1):
            inputs = torch.nn.functional.interpolate(
                inputs, scale_factor=args.scale_down)
        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
Example #15
def train(
    model: Union[nn.Module, nn.DataParallel],
    train_loader: DataLoader,
    metrics: Dict[str, Metric],
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    device: torch.device,
    epoch: int,
    log_interval: int,
    hooks: Optional[Sequence[Hook]] = None,
    teacher: Optional[Union[nn.Module, nn.DataParallel]] = None,
) -> Dict[str, float]:
    """
    Train a model on some data using some criterion and with some optimizer.

    Args:
        model: Model to train
        train_loader: Data loader for loading training data
        metrics: A dict mapping evaluation metric names to metrics classes
        optimizer: PyTorch optimizer
        scheduler: PyTorch scheduler
        device: PyTorch device object
        epoch: Current epoch, where the first epoch should start at 1
        log_interval: Number of batches before printing loss
        hooks: A sequence of functions that can implement custom behavior
        teacher: teacher network for knowledge distillation, if any

    Returns:
        A dictionary mapping evaluation metric names to computed values for the training set.
    """
    if hooks is None:
        hooks = []

    model.train()
    for metric in metrics.values():
        metric.reset()

    loss_fn = model.module.loss_fn if isinstance(
        model, nn.DataParallel) else model.loss_fn

    seen_examples = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        if teacher is None:
            teacher_output = None
            loss = loss_fn(output, target)  # type: ignore
        else:
            teacher_output = teacher(data)
            loss = loss_fn(output, teacher_output, target)  # type: ignore
        loss.backward()
        optimizer.step()
        project(optimizer)
        scheduler.step()  # type: ignore

        with torch.no_grad():
            for metric in metrics.values():
                metric.update(output, target, teacher_output=teacher_output)

        for hook in hooks:
            hook(
                epoch=epoch,
                global_step=1 + (epoch - 1) * len(train_loader.dataset) +
                batch_idx,
                values_dict={'lr': _get_lr(optimizer)},
                log_interval=log_interval,
            )

        seen_examples += len(data)
        if batch_idx % log_interval == 0:
            logger.info(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tBatch Loss: {:.6f}'.format(
                    epoch,
                    seen_examples,
                    len(train_loader.dataset),
                    100 * batch_idx / len(train_loader),
                    loss.item(),
                ))

    # Computing evaluation metrics for training set
    computed_metrics = {
        name: metric.compute()
        for name, metric in metrics.items()
    }

    logger.info('Training set evaluation metrics:')
    for name, metric in metrics.items():
        logger.info(f'{name}: {metric}')

    return computed_metrics
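
The optional teacher argument above routes teacher outputs into the model's own loss_fn for knowledge distillation; that loss is not shown here. A common formulation it might implement (softened KL plus cross-entropy; the temperature and mixing weight are illustrative) is:

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.9):
    """Blend a softened teacher-matching KL term with the usual cross-entropy."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction='batchmean',
    ) * (T * T)
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1 - alpha) * hard

s, t = torch.randn(8, 10), torch.randn(8, 10)
y = torch.randint(0, 10, (8,))
loss = distillation_loss(s, t, y)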
Example #16
def routine(
    model: nn.Module,
    dataloader: data.DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer = None,
    adversary: nn.Module = None,
    inverse: bool = False,
    descents: int = None,
    flow: bool = False,
    mask_sampler: nn.Module = None,
    clip: float = None,
) -> Tuple[float, torch.Tensor]:  # (time, losses)
    r"""Training routine"""

    if adversary is None:
        adversary = Dummy()

    losses = []

    start = time()

    for theta, theta_prime, x in islice(dataloader, descents):
        y = model.embedding(x)
        adv_y = adversary.embedding(x)

        if flow:
            prob = model(theta, y)
            l = criterion(prob)
        else:
            if mask_sampler is None:
                ratio, ratio_prime = model(
                    torch.stack((theta, theta_prime)),
                    torch.stack((y, y)),
                )

                with torch.no_grad():
                    adv_ratio = adversary(theta if inverse else theta_prime,
                                          adv_y)
            else:
                if model.hyper is None:
                    mask = mask_sampler(theta.shape[:1])
                else:
                    mask = mask_sampler()

                ratio, ratio_prime = model(
                    torch.stack((theta, theta_prime)),
                    torch.stack((y, y)),
                    torch.stack((mask, mask)) if model.hyper is None else mask,
                )

                with torch.no_grad():
                    adv_ratio = adversary(theta if inverse else theta_prime,
                                          adv_y, mask)

            if adv_ratio is not None:
                adv_ratio = (-adv_ratio if inverse else adv_ratio).exp()

            if inverse:
                l = criterion(ratio, adv_ratio) + criterion(-ratio_prime)
            else:
                l = criterion(ratio) + criterion(-ratio_prime, adv_ratio)

        if not l.isfinite():
            continue

        if optimizer is not None:
            optimizer.zero_grad()
            l.backward()

            if clip is not None:
                tot = nn.utils.clip_grad_norm_(model.parameters(), clip)

                if not tot.isfinite():
                    continue

            optimizer.step()

        losses.append(l.item())

    end = time()

    return end - start, torch.tensor(losses)
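
The routine above skips a batch whenever the loss or the clipped gradient norm is non-finite. A condensed version of that guard around a toy model:

import torch
import torch.nn as nn
import torch.optim as optim

net = nn.Linear(4, 1)
opt = optim.SGD(net.parameters(), lr=0.1)
clip = 1.0

def guarded_step(x, y):
    loss = nn.functional.mse_loss(net(x), y)
    if not loss.isfinite():
        return None                     # skip the batch entirely
    opt.zero_grad()
    loss.backward()
    total_norm = nn.utils.clip_grad_norm_(net.parameters(), clip)
    if not total_norm.isfinite():
        return None                     # gradients blew up: skip the update
    opt.step()
    return loss.item()

print(guarded_step(torch.randn(8, 4), torch.randn(8, 1)))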
Example #17
def meta_gradient_step(model: Module,
                       optimiser: Optimizer,
                       loss_fn: Callable,
                       x: torch.Tensor,
                       y: torch.Tensor,
                       n_shot: int,
                       k_way: int,
                       q_queries: int,
                       order: int,
                       inner_train_steps: int,
                       inner_lr: float,
                       train: bool,
                       device: Union[str, torch.device]):
    """
    Perform a gradient step on a meta-learner.

    # Arguments
        model: Base model of the meta-learner being trained
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples for all few shot tasks
        y: Input labels of all few shot tasks
        n_shot: Number of examples per class in the support set of each task
        k_way: Number of classes in the few shot classification task of each task
        q_queries: Number of examples per class in the query set of each task. The query set is used to calculate
            meta-gradients after applying the update to
        order: Whether to use 1st order MAML (update meta-learner weights with gradients of the updated weights on the
            query set) or 2nd order MAML (use 2nd order updates by differentiating through the gradients of the updated
            weights on the query with respect to the original weights).
        inner_train_steps: Number of gradient steps to fit the fast weights during each inner update
        inner_lr: Learning rate used to update the fast weights on the inner update
        train: Whether to update the meta-learner weights at the end of the episode.
        device: Device on which to run computation
    """
    data_shape = x.shape[2:]
    create_graph = (True if order == 2 else False) and train

    task_gradients = []
    task_losses = []
    task_predictions = []
    for meta_batch in x:
        # By construction x is a 5D tensor of shape: (meta_batch_size, n*k + q*k, channels, width, height)
        # Hence when we iterate over the first  dimension we are iterating through the meta batches
        x_task_train = meta_batch[:n_shot * k_way]
        x_task_val = meta_batch[n_shot * k_way:]

        # Create a fast model using the current meta model weights
        fast_weights = OrderedDict(model.named_parameters())

        # Train the model for `inner_train_steps` iterations
        for inner_batch in range(inner_train_steps):
            # Perform update of model weights
            y = create_nshot_task_label(k_way, n_shot).to(device)
            logits = model.functional_forward(x_task_train, fast_weights)
            loss = loss_fn(logits, y)
            gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)

            # Update weights manually
            fast_weights = OrderedDict(
                (name, param - inner_lr * grad)
                for ((name, param), grad) in zip(fast_weights.items(), gradients)
            )

        # Do a pass of the model on the validation data from the current task
        y = create_nshot_task_label(k_way, q_queries).to(device)
        logits = model.functional_forward(x_task_val, fast_weights)
        loss = loss_fn(logits, y)
        loss.backward(retain_graph=True)

        # Get post-update accuracies
        y_pred = logits.softmax(dim=1)
        task_predictions.append(y_pred)

        # Accumulate losses and gradients
        task_losses.append(loss)
        gradients = torch.autograd.grad(loss, fast_weights.values(), create_graph=create_graph)
        named_grads = {name: g for ((name, _), g) in zip(fast_weights.items(), gradients)}
        task_gradients.append(named_grads)

    if order == 1:
        if train:
            sum_task_gradients = {k: torch.stack([grad[k] for grad in task_gradients]).mean(dim=0)
                                  for k in task_gradients[0].keys()}
            hooks = []
            for name, param in model.named_parameters():
                hooks.append(
                    param.register_hook(replace_grad(sum_task_gradients, name))
                )

            model.train()
            optimiser.zero_grad()
            # Dummy pass in order to create `loss` variable
            # Replace dummy gradients with mean task gradients using hooks
            logits = model(torch.zeros((k_way, ) + data_shape).to(device, dtype=torch.double))
            loss = loss_fn(logits, create_nshot_task_label(k_way, 1).to(device))
            loss.backward()
            optimiser.step()

            for h in hooks:
                h.remove()

        return torch.stack(task_losses).mean(), torch.cat(task_predictions)

    elif order == 2:
        model.train()
        optimiser.zero_grad()
        meta_batch_loss = torch.stack(task_losses).mean()

        if train:
            meta_batch_loss.backward()
            optimiser.step()

        return meta_batch_loss, torch.cat(task_predictions)
    else:
        raise ValueError('Order must be either 1 or 2.')
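
The inner loop above computes gradients with torch.autograd.grad and applies a manual SGD update to a copy of the weights ("fast weights"). A minimal standalone illustration of one such update on a functional linear classifier:

import torch
import torch.nn.functional as F
from collections import OrderedDict

# Fast weights for a tiny linear classifier: logits = x @ W.T + b
fast_weights = OrderedDict(
    W=torch.randn(3, 5, requires_grad=True),
    b=torch.zeros(3, requires_grad=True),
)
x, y = torch.randn(8, 5), torch.randint(0, 3, (8,))
inner_lr = 0.1

logits = F.linear(x, fast_weights['W'], fast_weights['b'])
loss = F.cross_entropy(logits, y)
grads = torch.autograd.grad(loss, fast_weights.values())

# One inner SGD step producing new fast weights; create_graph=True in the
# grad call above would keep this update differentiable for 2nd-order MAML.
fast_weights = OrderedDict(
    (name, param - inner_lr * grad)
    for (name, param), grad in zip(fast_weights.items(), grads)
)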
Example #18
    def forward(self,
                x,
                opt: optim.Optimizer,
                step,
                summary_writer: torch.utils.tensorboard.SummaryWriter = None,
                sample_gpu=None):
        """
        train inside forward
        """
        opt.zero_grad()
        batch_size, num_pts = x.shape[:2]
        z_mu, z_sigma = self.encoder(x)
        # Compute Q(z|X) and entropy H{Q(z|X)}
        if self.use_deterministic_encoder:
            z = z_mu + 0 * z_sigma  # ? why, the original code added this 0 multiplier
            entropy = torch.zeros(batch_size).to(z)
        else:
            z = self.reparametrized_gaussian(z_mu, z_sigma)
            entropy = self.gaussian_entropy(z_sigma)

        # Compute prior P(z)
        if self.use_latent_flow:
            w, dlog_pw = self.latentCNF(z, None,
                                        torch.zeros(batch_size, 1).to(z))
            log_pw = standard_normal_logp(w).view(batch_size,
                                                  -1).sum(dim=1, keepdim=True)
            dlog_pw = dlog_pw.view(batch_size, 1).to(z)
            log_pz = log_pw - dlog_pw
        else:
            log_pz = torch.zeros(batch_size, 1).to(z)

        # Compute recon. P(X|z)
        z_new = z.view(z.shape) + (log_pz * 0.).mean()  # ? why
        y, dlog_py = self.pointCNF(x, z_new,
                                   torch.zeros(batch_size, num_pts, 1).to(x))
        log_py = standard_normal_logp(y).view(batch_size, -1).sum(dim=1,
                                                                  keepdim=True)
        dlog_py = dlog_py.view(batch_size, num_pts, 1).to(x)
        log_px = log_py - dlog_py

        # Loss
        entropy_loss = -entropy.mean() * self.entropy_w
        recon_loss = -log_px.mean() * self.recon_w
        prior_loss = -log_pz.mean() * self.prior_w
        loss = entropy_loss + recon_loss + prior_loss
        loss.backward()
        opt.step()

        # Write logs
        if self.distributed:
            raise NotImplementedError("Distributed training not implemented!")
        else:
            entropy_log = entropy.mean()
            recon_log = -log_px.mean()
            prior_log = -log_pz.mean()

        recon_nats = recon_log / float(x.size(1) * x.size(2))
        prior_nats = prior_log / float(self.fz)

        # reconstruct to save
        with torch.no_grad():
            recon_pc = self.reconstruct(x, truncate_std=True)
            recon_im = visualize(recon_pc,
                                 path='/home/tmp/screenshot.png',
                                 samples=1)

        # sample to save
        if self.use_latent_flow:
            with torch.no_grad():
                sample_pc = self.sample(1, 1024, gpu=sample_gpu)
                sample_im = visualize(sample_pc,
                                      samples=1,
                                      path='/home/tmp/screenshot.png')

        record_dict = {
            'train/entropy':
            entropy_log.cpu().detach().item()
            if not isinstance(entropy_log, float) else entropy_log,
            'train/prior':
            prior_log,
            'train/recon':
            recon_log,
            'train/recon-nats':
            recon_nats,
            'train/prior-nats':
            prior_nats,
            # 'train/sample-reconstructed': recon_pc
        }

        if summary_writer is not None:
            for key, value in record_dict.items():
                summary_writer.add_scalar(key, value, step)

        record_dict['train/sample-reconstructed'] = recon_im
        if summary_writer is not None:
            summary_writer.add_images('train/sample-reconstructed',
                                      recon_im,
                                      step,
                                      dataformats='NHWC')
        if self.use_latent_flow:
            record_dict['train/sample-sampled'] = sample_im
            if summary_writer is not None:
                summary_writer.add_images('train/sample-sampled',
                                          sample_im,
                                          step,
                                          dataformats='NHWC')
        return record_dict
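Both CNF terms above go through `standard_normal_logp`, which is not shown here; a minimal sketch, assuming it returns the element-wise log-density of a standard normal (callers sum it over the non-batch dimensions):

import math
import torch


def standard_normal_logp(z: torch.Tensor) -> torch.Tensor:
    """Element-wise log N(z; 0, I)."""
    return -0.5 * math.log(2 * math.pi) - 0.5 * z.pow(2)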
Beispiel #19
0
def train_controller(max_iter: int,
                     database: DataBase,
                     entropy_coeff: float,
                     grad_clip: int,
                     controller: NASBenchController,
                     nac: NAC,
                     optimizer: optim.Optimizer,
                     writer: tensorboard.SummaryWriter,
                     alternate_train,
                     alternate_evaluate,
                     random_baseline=False,
                     log_frequence: int = 10,
                     search_space=None):
    controller.train()
    nac.eval()
    optimizer.zero_grad()

    policy_loss_avg = MovingAverageMetric()
    entropy_mavg = MovingAverageMetric()
    logp_mavg = MovingAverageMetric()
    score_avg = MovingAverageMetric()

    pseudo_architecture_set = None

    with torch.no_grad():
        *arch_seq, _, _ = controller(force_uniform=True)
        raw_arch = seq2arch_fn(arch_seq)
        baseline_arch = [tensorize_fn(raw_arch, device=device)]

    best_collect_archs = [arch_seq]

    for iter_ in range(max_iter):

        if iter_ % args.n_iteration_update_pseudoset == 0 and args.pseudo_ratio != 0:
            if pseudo_architecture_set is None:
                pseudo_architecture_set = \
                    generate_architecture_with_pseudo_labels(
                        nac, controller,
                        2*int(args.pseudo_ratio*args.train_batch_size),
                        int(args.pseudo_ratio*args.train_batch_size))
            else:
                pseudo_architecture_set = list_concat(
                    pseudo_architecture_set,
                    generate_architecture_with_pseudo_labels(
                        nac, controller, 2 * args.n_sample_architectures,
                        args.n_sample_architectures))

            epoch = args.nac_epochs + iter_
            accuracy, rank_loss = alternate_train(
                epoch=epoch, pseudo_set=pseudo_architecture_set)
            writer.add_scalar("nac/train_accuracy", accuracy, epoch)
            writer.add_scalar("nac/loss", rank_loss, epoch)
            KTau = alternate_evaluate(epoch=epoch)
            writer.add_scalar("nac/ktau", KTau, epoch)

        *arch_seq, logp, entropy = controller()
        with torch.no_grad():
            sample_arch = [tensorize_fn(seq2arch_fn(arch_seq), device=device)]
            score = nac(batchify(sample_arch), batchify(baseline_arch))
            score = score.mean().item()

        policy_loss = -logp * score - entropy_coeff * entropy

        optimizer.zero_grad()
        policy_loss.backward()
        if grad_clip is not None:
            # Clip after backward so the freshly computed gradients are clipped.
            nn.utils.clip_grad_norm_(controller.parameters(), grad_clip)
        optimizer.step()

        policy_loss_avg.update(policy_loss)
        entropy_mavg.update(entropy)
        logp_mavg.update(logp)
        score_avg.update(score)

        if iter_ % log_frequence == 0:
            logger.info(", ".join([
                "Policy Learning",
                f"iter={iter_:03d}",
                f"policy loss={policy_loss_avg.compute():.4f}",
                f"entropy={entropy_mavg.compute():.4f}",
                f"logp={logp_mavg.compute():.4f}",
            ]))
            writer.add_scalar("policy_learning/loss",
                              policy_loss_avg.compute(), iter_)
            writer.add_scalar("policy_learning/entropy",
                              entropy_mavg.compute(), iter_)
            writer.add_scalar("policy_learning/logp", logp_mavg.compute(),
                              iter_)
            writer.add_scalar("policy_learning/reward", score_avg.compute(),
                              iter_)

        if iter_ % args.evaluate_controller_freq == 0:
            baseline_arch, best_collect_archs = derive(iter_, controller, nac,
                                                       10, database, writer,
                                                       best_collect_archs,
                                                       random_baseline,
                                                       search_space)
            torch.save(controller.state_dict(),
                       os.path.join(args.output, f"controller-{iter_}.path"))
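`MovingAverageMetric` comes from outside this snippet; a minimal sketch, assuming `update` accepts floats or 0-d tensors and `compute` returns a running mean used purely for logging:

import torch


class MovingAverageMetric:
    """Running mean of every value passed to update()."""

    def __init__(self):
        self._sum = 0.0
        self._count = 0

    def update(self, value):
        if isinstance(value, torch.Tensor):
            value = value.detach().item()
        self._sum += float(value)
        self._count += 1

    def compute(self) -> float:
        return self._sum / max(self._count, 1)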
Beispiel #20
0
def train(loader: DataLoader, model: torch.nn.Module, criterion,
          optimizer: Optimizer, epoch: int, noise_sd: float):
    """
    Function to do one training epoch
        :param loader:DataLoader: dataloader (train) 
        :param model:torch.nn.Module: the classifer being trained
        :param criterion: the loss function
        :param optimizer:Optimizer: the optimizer used during trainined
        :param epoch:int: the current epoch number (for logging)
        :param noise_sd:float: the std-dev of the Guassian noise perturbation of the input
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()

    for i, (inputs, targets) in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        inputs = inputs.cuda()
        targets = targets.cuda()

        # augment inputs with noise
        inputs = inputs + torch.randn_like(inputs, device='cuda') * noise_sd

        # compute output
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1.item(), inputs.size(0))
        top5.update(acc5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
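Several of these loops rely on an `AverageMeter` that is not reproduced here. They appear to follow the helper from the standard PyTorch ImageNet example (`val`/`avg` attributes, `update(value, n)`), which looks roughly like this:

class AverageMeter:
    """Track the most recent value and the running average of a quantity."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val: float, n: int = 1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count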
Beispiel #21
0
def update_w(epoch: int, data_loader, device: str, master_pair: MasterPairs,
             architecture: NASNetwork, criterion: nn.Module,
             optimizer: optim.Optimizer, force_uniform: bool,
             writer: SummaryWriter, log_frequency: int):
    start = datetime.now()
    loss_metric = AverageMetric()
    accuracy_metric = AccuracyMetric(topk=(1, 5))
    normal_logp_metric = AverageMetric()
    node_normal_entropy_metric = AverageMetric()
    op_normal_entropy_metric = AverageMetric()
    reduced_logp_metric = AverageMetric()
    node_reduced_entropy_metric = AverageMetric()
    op_reduced_entropy_metric = AverageMetric()

    master_pair.set_force_uniform(force_uniform=force_uniform)

    for iter_, (datas, targets) in enumerate(data_loader, start=1):
        datas, targets = datas.to(device=device), targets.to(device=device)
        with torch.no_grad():
            (normal_arch, normal_logp, node_normal_entropy, op_normal_entropy), \
                (reduced_arch, reduced_logp, node_reduced_entropy,
                 op_reduced_entropy) = master_pair()

        outputs = architecture(datas, normal_arch, reduced_arch)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update metrics
        loss_metric.update(loss)
        accuracy_metric.update(targets, outputs)
        normal_logp_metric.update(normal_logp)
        node_normal_entropy_metric.update(node_normal_entropy)
        op_normal_entropy_metric.update(op_normal_entropy)
        reduced_logp_metric.update(reduced_logp)
        node_reduced_entropy_metric.update(node_reduced_entropy)
        op_reduced_entropy_metric.update(op_reduced_entropy)

        # iteration log
        if iter_ % log_frequency == 0 or iter_ == len(data_loader):
            message = f"UPDATE W, epoch={epoch:03d}, iter={iter_}/{len(data_loader)}, "
            message += f"celoss={loss_metric.last:.4f}({loss_metric.value:.4f}), "
            message += f"accuracy@1={accuracy_metric.last_accuracy(1).rate*100:.2f}%"
            message += f"({accuracy_metric.accuracy(1).rate*100:.2f}%), "
            message += f"accuracy@5={accuracy_metric.last_accuracy(5).rate*100:.2f}%"
            message += f"({accuracy_metric.accuracy(5).rate*100:.2f}%), "
            message += f"normal_logp={normal_logp_metric.last:.4f}({normal_logp_metric.value:.4f}), "
            message += f"node_normal_entropy={node_normal_entropy_metric.last:.4f}({node_normal_entropy_metric.value:.4f}), "
            message += f"op_normal_entropy={op_normal_entropy_metric.last:.4f}({op_normal_entropy_metric.value:.4f}), "
            message += f"reduced_logp={reduced_logp_metric.last:.4f}({reduced_logp_metric.value:.4f}), "
            message += f"node_reduced_entropy={node_reduced_entropy_metric.last:.4f}({node_reduced_entropy_metric.value:.4f}), "
            message += f"op_reduced_entropy={op_reduced_entropy_metric.last:.4f}({op_reduced_entropy_metric.value:.4f})."
            if iter_ == len(data_loader):
                message += f" Eplased time={datetime.now()-start}."
            utils.logger.info(message)

    writer.add_scalar("update_w/celoss", loss_metric.value, epoch)
    writer.add_scalar("update_w/accuracy@1",
                      accuracy_metric.accuracy(1).rate, epoch)
    writer.add_scalar("update_w/accuracy@5",
                      accuracy_metric.accuracy(5).rate, epoch)
    writer.add_scalar("update_w/normal_logp", normal_logp_metric.value, epoch)
    writer.add_scalar("update_w/node_normal_entropy",
                      node_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_w/op_normal_entropy",
                      op_normal_entropy_metric.value, epoch)
    writer.add_scalar("update_w/reduced_logp", reduced_logp_metric.value,
                      epoch)
    writer.add_scalar("update_w/node_reduced_entropy",
                      node_reduced_entropy_metric.value, epoch)
    writer.add_scalar("update_w/op_reduced_entropy",
                      op_reduced_entropy_metric.value, epoch)
def matching_net_episode(model: Module, optimiser: Optimizer, loss_fn: Loss,
                         x: torch.Tensor, y: torch.Tensor, n_shot: int,
                         k_way: int, q_queries: int, distance: str, fce: bool,
                         train: bool):
    """Performs a single training episode for a Matching Network.

    # Arguments
        model: Matching Network to be trained.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples of few shot classification task
        y: Input labels of few shot classification task
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between support and query set samples
        fce: Whether or not to use fully conditional embeddings
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Matching Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model.encoder(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[:n_shot * k_way]
    queries = embeddings[n_shot * k_way:]

    # Optionally apply full context embeddings
    if fce:
        # LSTM requires input of shape (seq_len, batch, input_size). `support` is of
        # shape (k_way * n_shot, embedding_dim) and we want the LSTM to treat the
        # support set as a sequence so add a single dimension to transform support set
        # to the shape (k_way * n_shot, 1, embedding_dim) and then remove the batch dimension
        # afterwards

        # Calculate the fully conditional embedding, g, for support set samples as described
        # in appendix A.2 of the paper. g takes the form of a bidirectional LSTM with a
        # skip connection from inputs to outputs
        support, _, _ = model.g(support.unsqueeze(1))
        support = support.squeeze(1)

        # Calculate the fully conditional embedding, f, for the query set samples as described
        # in appendix A.1 of the paper.
        queries = model.f(support, queries)

    # Efficiently calculate distance between all queries and all support samples
    # Output should have shape (q_queries * k_way, n_shot * k_way)
    distances = pairwise_distances(queries, support, distance)

    # Calculate "attention" as softmax over support-query distances
    attention = (-distances).softmax(dim=1)

    # Calculate predictions as in equation (1) from Matching Networks
    # y_hat = \sum_{i=1}^{k} a(x_hat, x_i) y_i
    y_pred = matching_net_predictions(attention, n_shot, k_way, q_queries)

    # Calculate loss with negative log likelihood
    # Clip predictions for numerical stability
    clipped_y_pred = y_pred.clamp(EPSILON, 1 - EPSILON)
    loss = loss_fn(clipped_y_pred.log(), y)

    if train:
        # Backpropagate gradients
        loss.backward()
        # I found training to be quite unstable so I clip the norm
        # of the gradient to be at most 1
        clip_grad_norm_(model.parameters(), 1)
        # Take gradient step
        optimiser.step()

    return loss, y_pred
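`matching_net_predictions` converts the attention weights into class probabilities by summing the attention mass that falls on each class's support samples (equation (1) above). A minimal sketch, assuming the support set keeps the n-shot-per-class ordering described earlier:

import torch


def matching_net_predictions(attention: torch.Tensor, n: int, k: int,
                             q: int) -> torch.Tensor:
    """Map (q * k, n * k) attention weights to (q * k, k) class probabilities."""
    # One-hot labels of the support set: n consecutive samples per class.
    y = torch.arange(0, k, 1 / n).long().to(attention.device)
    y_onehot = torch.zeros(k * n, k, device=attention.device)
    y_onehot.scatter_(1, y.unsqueeze(-1), 1)
    # Attention-weighted sum of support labels.
    return torch.mm(attention, y_onehot)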
def update_model(
    optimizer: optim.Optimizer,
    scaler: amp.grad_scaler.GradScaler,
    buffer: Buffer,
    state: TSP2OPTState,
    done: bool,
    epoch: int,
    count: int,
    learn_count: int,
    global_step: int,
    logger: SummaryWriter,
    args,
):

    rewards = torch.stack(buffer.rewards, dim=0)  # [horizon, batch_size, 1]
    returns = discounted_return(rewards, args.gamma,
                                count)  # [horizon, batch_size, 1]
    if not args.no_norm_return:
        r_mean = returns.mean()
        r_std = returns.std()
        eps = torch.finfo(torch.float).eps  # small number to avoid div/0
        returns = (returns - r_mean) / (r_std + eps)
    values = torch.stack(buffer.values, dim=0)  # [horizon, batch_size, 1]
    advantages = (returns - values).detach()  # [horizon, batch_size, 1]

    logps = torch.stack(buffer.log_probs,
                        dim=0)  # [horizon, batch_size, 2, graph_size]
    actions = torch.stack(buffer.actions, dim=0)  # [horizon, batch_size, 2, 1]
    log_likelihood = logps.gather(-1, actions).squeeze(
        -1)  # [horizon, batch_size, 2]
    log_likelihood = log_likelihood.mean(2).unsqueeze(
        2)  # [horizon, batch_size, 1]

    entropies = log_p_to_entropy(logps).mean(2).unsqueeze(
        2)  # [horizon, batch_size, 1]

    p_loss = (-log_likelihood * advantages).mean()
    v_loss = args.value_beta * (returns - values).pow(2).mean()
    e_loss = (0.9**(epoch + 1)) * args.entropy_beta * entropies.sum(0).mean()
    r_loss = -e_loss + v_loss
    loss = p_loss + r_loss

    optimizer.zero_grad()
    scaler.scale(p_loss).backward(retain_graph=True)
    # scaler.unscale_(optimizer)
    grad_norms = clip_grad_norms(
        optimizer.param_groups)  #, args.max_grad_norm)
    scaler.scale(r_loss).backward(retain_graph=False)
    scaler.step(optimizer)
    scaler.update()

    buffer.clear_buffer()
    log_values(
        cost=state.best_tour_len,
        grad_norms=grad_norms,
        done=done,
        epoch=epoch,
        global_step=global_step,
        learn_count=learn_count,
        p_loss=p_loss,
        v_loss=v_loss,
        e_loss=e_loss,
        loss=loss,
        returns=returns.mean(),
        value=values.mean(),
        entropy=entropies.detach().mean(),
        logger=logger,
        args=args,
    )

    learn_count += 1

    return learn_count
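`discounted_return` is assumed here to fold the per-step rewards into discounted returns along the horizon dimension (the `count` argument passed above would then be the number of steps in the buffer). A sketch under that assumption:

import torch


def discounted_return(rewards: torch.Tensor, gamma: float,
                      horizon: int) -> torch.Tensor:
    """G_t = r_t + gamma * G_{t+1} over a [horizon, batch_size, 1] tensor."""
    returns = torch.zeros_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(horizon)):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns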
Beispiel #24
0
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: Callable,
    optimizer: optim.Optimizer,
    device: torch.device,
    train_eval_freq: int = 50,
    clip_grad_norm: float = 1.0,
    verbose: bool = True,
) -> DefaultDict[str, List[float]]:
    """
    Training loop on one epoch.

    :param nn.Module model: PyTorch Neural Network
    :param DataLoader dataloader: PyTorch DataLoader
    :param Callable criterion: PyTorch Criterion
    :param optim.Optimizer optimizer: PyTorch Optimizer
    :param torch.device device: PyTorch Device
    :param int train_eval_freq: evaluation frequency (number of batches) (default: 50)
    :param float clip_grad_norm: max_norm parameter in clip_grad_norm (default: 1.0)
    :param bool verbose: verbose (default: True)
    :return: metrics dict
    :rtype: DefaultDict[str, List[float]]
    """

    metrics: DefaultDict[str, List[float]] = defaultdict(list)

    char2idx = dataloader.dataset.char2idx

    # BOS and EOS
    bos_id = char2idx[BOS]
    eos_id = char2idx[EOS]

    if verbose:
        dataloader = tqdm(dataloader, desc="iter dataloader")

    model.train()

    for i, sentence in enumerate(dataloader):
        sentence = sentence.to(device)

        # lengths and mask
        targets = sentence[:, 1:]  # clip left
        lengths = infer_lengths(sentence, bos_id=bos_id, eos_id=eos_id)
        mask = masking(lengths + 1)  # incl. EOS

        # forward pass
        outputs = model(
            sentence[:, :-1],  # clip right
            lengths + 1,  # incl. BOS
        )
        loss_matrix = criterion(
            input=outputs.transpose(1, 2),
            target=targets,
        )
        loss = (loss_matrix * mask).sum() / mask.sum()

        # backward pass
        loss.backward()

        # clip grad norm
        grad_norm = nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=clip_grad_norm,
        )

        # optimizer step
        optimizer.step()
        optimizer.zero_grad()

        # calculate metrics
        metrics["loss"].append(loss.item())
        metrics["grad_norm"].append(grad_norm.item())

        if verbose:
            if i % train_eval_freq == 0:
                generated_sequence = generate(
                    model=model,
                    char2idx=char2idx,
                    prefix="",
                    temperature=0.5,  # hardcoded
                    max_length=100,  # hardcoded
                )
                model.train()  # switch back to train mode after generation

                for metric_name, metric_list in metrics.items():
                    print(
                        f"{metric_name}: {np.mean(metric_list[-train_eval_freq:])}"
                    )
                print(f"inference: {generated_sequence}\n")

    return metrics
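The masked loss above depends on a `masking` helper defined elsewhere; a minimal sketch, assuming it builds a float mask of shape (batch, max_len) with ones at positions before each sequence's length:

import torch


def masking(lengths: torch.Tensor, max_length: int = None) -> torch.Tensor:
    """Float mask with 1.0 at positions < length for each sequence."""
    max_length = max_length or int(lengths.max().item())
    positions = torch.arange(max_length, device=lengths.device)
    return (positions.unsqueeze(0) < lengths.unsqueeze(1)).float()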
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          noise_sd: float,
          attacker: Attacker,
          device: torch.device,
          writer=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_reg = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = _chunk_minibatch(batch, args.num_noise_vec)
        for inputs, targets in mini_batches:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_size = inputs.size(0)

            noises = [
                torch.randn_like(inputs, device=device) * noise_sd
                for _ in range(args.num_noise_vec)
            ]

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model, inputs, targets, noises=noises)
                model.train()
                requires_grad_(model, True)

            # augment inputs with noise
            inputs_c = torch.cat([inputs + noise for noise in noises], dim=0)
            targets_c = targets.repeat(args.num_noise_vec)

            logits = model(inputs_c)
            loss_xent = criterion(logits, targets_c)

            logits_chunk = torch.chunk(logits, args.num_noise_vec, dim=0)
            loss_con = consistency_loss(logits_chunk, args.lbd, args.eta)

            loss = loss_xent + loss_con

            acc1, acc5 = accuracy(logits, targets_c, topk=(1, 5))
            losses.update(loss_xent.item(), batch_size)
            losses_reg.update(loss_con.item(), batch_size)
            top1.update(acc1.item(), batch_size)
            top5.update(acc5.item(), batch_size)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.avg:.3f}\t'
                  'Data {data_time.avg:.3f}\t'
                  'Loss {loss.avg:.4f}\t'
                  'Acc@1 {top1.avg:.3f}\t'
                  'Acc@5 {top5.avg:.3f}'.format(epoch,
                                                i,
                                                len(loader),
                                                batch_time=batch_time,
                                                data_time=data_time,
                                                loss=losses,
                                                top1=top1,
                                                top5=top5))

    if writer is not None:
        writer.add_scalar('loss/train', losses.avg, epoch)
        writer.add_scalar('loss/consistency', losses_reg.avg, epoch)
        writer.add_scalar('batch_time', batch_time.avg, epoch)
        writer.add_scalar('accuracy/train@1', top1.avg, epoch)
        writer.add_scalar('accuracy/train@5', top5.avg, epoch)

    return (losses.avg, top1.avg)
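`consistency_loss` is imported from elsewhere; in the consistency-regularization formulation it penalises disagreement between the predictions on the different noisy copies of an input, typically as a KL term against the mean prediction plus an entropy term weighted by `eta`. A rough sketch under that reading (not necessarily the exact implementation used here):

import torch
import torch.nn.functional as F


def consistency_loss(logits_chunks, lbd: float, eta: float) -> torch.Tensor:
    """lbd * mean_i KL(p_mean || p_i) + eta * H(p_mean) over noisy copies."""
    probs = [F.softmax(logits, dim=1) for logits in logits_chunks]
    p_mean = torch.stack(probs, dim=0).mean(dim=0).clamp(min=1e-12)
    kl = sum(
        F.kl_div(F.log_softmax(logits, dim=1), p_mean, reduction='batchmean')
        for logits in logits_chunks) / len(logits_chunks)
    entropy = -(p_mean * p_mean.log()).sum(dim=1).mean()
    return lbd * kl + eta * entropy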
Beispiel #26
0
def proto_net_episode(model: Module,
                      optimiser: Optimizer,
                      loss_fn: Callable,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      n_shot: int,
                      k_way: int,
                      q_queries: int,
                      distance: str,
                      train: bool):
    """Performs a single training episode for a Prototypical Network.

    # Arguments
        model: Prototypical Network to be trained.
        optimiser: Optimiser to calculate gradient step
        loss_fn: Loss function to calculate between predictions and outputs. Should be cross-entropy
        x: Input samples of few shot classification task
        y: Input labels of few shot classification task
        n_shot: Number of examples per class in the support set
        k_way: Number of classes in the few shot classification task
        q_queries: Number of examples per class in the query set
        distance: Distance metric to use when calculating distance between class prototypes and queries
        train: Whether (True) or not (False) to perform a parameter update

    # Returns
        loss: Loss of the Prototypical Network on this task
        y_pred: Predicted class probabilities for the query set on this task
    """
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes

    support = embeddings[:n_shot * k_way]  # (n_shot * k_way, embedding_dim)
    queries = embeddings[n_shot * k_way:]  # (q_queries * k_way, embedding_dim)

    prototypes = compute_prototypes(support, k_way, n_shot)

    # Calculate squared distances between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = pairwise_distances(queries, prototypes, distance)

    # Calculate log p_{phi} (y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y)

    # Prediction probabilities are softmax over distances
    y_pred = (-distances).softmax(dim=1)

    if train:
        # Take gradient step
        loss.backward()
        optimiser.step()

    return loss, y_pred
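`compute_prototypes` reduces the support embeddings to one mean vector per class; a minimal sketch, assuming the n-shot-per-class ordering described in the comments above:

import torch


def compute_prototypes(support: torch.Tensor, k: int, n: int) -> torch.Tensor:
    """Class prototypes = mean of each class's n support embeddings."""
    # support is ordered as k blocks of n samples: (k * n, d) -> (k, n, d)
    return support.reshape(k, n, -1).mean(dim=1)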
def train_stego(*, stegoanalyser: nn.Module,
                train_iterator: DataBatchIterator,
                val_iterator: DataBatchIterator,
                text_iterator: Iterator,
                n_epoch: int, stegoanalyser_opt: Optimizer,
                callbacks: Sequence[Callable] = None, logger: TBLogger,
                encoder: SigmoidTorchEncoder):
    criterion = F.binary_cross_entropy_with_logits
    callbacks = callbacks or []

    for epoch in tqdm(range(n_epoch)):
        stegoanalyser_losses = []
        with train_iterator as iterator:
            for real_batch, _ in iterator:
                batch_size = len(real_batch)
                labels = np.random.choice([0, 1], (batch_size, 1, 1, 1))
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)
                labels = torch.from_numpy(labels).float()
                # train stegoanalyzer
                stegoanalyser_opt.zero_grad()
                stegoanalyser_losses.append(
                    process_batch(encoded_images.detach(), labels, stegoanalyser, criterion))
                stegoanalyser_opt.step()

        with val_iterator as iterator:
            accuracy = []
            for real_batch, _ in iterator:
                batch_size = len(real_batch)

                labels = np.random.choice([0, 1], batch_size)
                encoded_images = []
                for image, label in zip(real_batch, labels):
                    if label == 1:
                        msg = bytes_to_bits(next(text_iterator))
                        key = generate_random_key(image.shape[1:], len(msg))
                        image = encoder.encode(transform_encoder(image), msg, key)
                        image = inverse_transform_encoder(image)
                    encoded_images.append(image)

                encoded_images = torch.stack(encoded_images)
                # evaluate stegoanalyzer
                out = inference_step(encoded_images, stegoanalyser).cpu().detach()
                out = torch.sigmoid(out) > 0.5
                out = out.reshape(len(encoded_images)).numpy()
                accuracy_score = sklearn.metrics.accuracy_score(labels, out)
                accuracy.append(accuracy_score)

            mean_accuracy = np.mean(accuracy)
            print(f'validation accuracy score {mean_accuracy}')

            losses = {'Stegoanalyser loss': np.mean(stegoanalyser_losses),
                      'Val accuracy': mean_accuracy}
            logger.policies(losses, epoch)

            # run callbacks
            for callback in callbacks:
                callback(epoch)
Beispiel #28
0
def train(loader: DataLoader,
          model: torch.nn.Module,
          criterion,
          optimizer: Optimizer,
          epoch: int,
          noise_sd: float,
          attacker: Attacker = None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    # switch to train mode
    model.train()
    requires_grad_(model, True)

    for i, batch in enumerate(loader):
        # measure data loading time
        data_time.update(time.time() - end)

        mini_batches = get_minibatches(batch, args.num_noise_vec)
        noisy_inputs_list = []
        for inputs, targets in mini_batches:
            inputs = inputs.cuda()
            targets = targets.cuda()

            inputs = inputs.repeat(
                (1, args.num_noise_vec, 1, 1)).view(batch[0].shape)

            # augment inputs with noise
            noise = torch.randn_like(inputs, device='cuda') * noise_sd

            if args.adv_training:
                requires_grad_(model, False)
                model.eval()
                inputs = attacker.attack(model,
                                         inputs,
                                         targets,
                                         noise=noise,
                                         num_noise_vectors=args.num_noise_vec,
                                         no_grad=args.no_grad_attack)
                model.train()
                requires_grad_(model, True)

            if args.train_multi_noise:
                noisy_inputs = inputs + noise
                targets = targets.unsqueeze(1).repeat(
                    1, args.num_noise_vec).reshape(-1, 1).squeeze()
                outputs = model(noisy_inputs)
                loss = criterion(outputs, targets)

                acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
                losses.update(loss.item(), noisy_inputs.size(0))
                top1.update(acc1.item(), noisy_inputs.size(0))
                top5.update(acc5.item(), noisy_inputs.size(0))

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            else:
                inputs = inputs[::args.num_noise_vec]  # subsample the samples
                noise = noise[::args.num_noise_vec]
                # noise = torch.randn_like(inputs, device='cuda') * noise_sd
                noisy_inputs_list.append(inputs + noise)

        if not args.train_multi_noise:
            noisy_inputs = torch.cat(noisy_inputs_list)
            targets = batch[1].cuda()
            assert len(targets) == len(noisy_inputs)

            outputs = model(noisy_inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), noisy_inputs.size(0))
            top1.update(acc1.item(), noisy_inputs.size(0))
            top5.update(acc5.item(), noisy_inputs.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch,
                      i,
                      len(loader),
                      batch_time=batch_time,
                      data_time=data_time,
                      loss=losses,
                      top1=top1,
                      top5=top5))

    return (losses.avg, top1.avg)
def torch_single_train(model: PyTorchForecast,
                       opt: optim.Optimizer,
                       criterion: Type[torch.nn.modules.loss._Loss],
                       data_loader: DataLoader,
                       takes_target: bool,
                       meta_data_model: PyTorchForecast,
                       meta_data_model_representation: torch.Tensor,
                       meta_loss=None,
                       multi_targets=1,
                       forward_params: Dict = {}) -> float:
    probablistic = None
    if "probabilistic" in model.params["model_params"]:
        probablistic = True
    print('running torch_single_train')
    i = 0
    output_std = None
    running_loss = 0.0
    for src, trg in data_loader:
        opt.zero_grad()
        # Convert to CPU/GPU/TPU
        src = src.to(model.device)
        trg = trg.to(model.device)
        if meta_data_model:
            representation = meta_data_model.model.generate_representation(
                meta_data_model_representation)
            forward_params["meta_data"] = representation
            if meta_loss:
                output = meta_data_model.model(meta_data_model_representation)
                met_loss = compute_loss(meta_data_model_representation, output,
                                        torch.rand(2, 3, 2), meta_loss, None)
                met_loss.backward()
        if takes_target:
            forward_params["t"] = trg
        output = model.model(src, **forward_params)
        if multi_targets == 1:
            labels = trg[:, :, 0]
        elif multi_targets > 1:
            labels = trg[:, :, 0:multi_targets]
        if probablistic:
            output1 = output
            output = output.mean
            output_std = output1.stddev
        loss = compute_loss(labels,
                            output,
                            src,
                            criterion,
                            None,
                            probablistic,
                            output_std,
                            m=multi_targets)
        if loss > 100:
            print("Warning: high loss detected")
        loss.backward()
        opt.step()
        if torch.isnan(loss) or loss == float('inf'):
            raise ValueError(
                "Error infinite or NaN loss detected. Try normalizing data or performing interpolation"
            )
        running_loss += loss.item()
        i += 1
    print("The running loss is: ")
    print(running_loss)
    print("The number of items in train is: " + str(i))
    total_loss = running_loss / float(i)
    return total_loss
Beispiel #30
0
def train_model(
    train_ds: tf.data.Dataset,
    dev_ds: tf.data.Dataset,
    model: nn.Module,
    optimizer: optim.Optimizer,
    lr_scheduler: optim.lr_scheduler._LRScheduler,
    args: argparse.Namespace,
) -> nn.Module:

    device = model_utils.get_device()
    loss_fn = model_utils.depth_proportional_loss
    val_loss_fn = model_utils.l1_norm_loss
    best_val_loss = torch.tensor(float('inf'))
    saved_checkpoints = []
    writer = SummaryWriter(log_dir=f'{args.log_dir}/{args.experiment}')

    cos = nn.CosineSimilarity(dim=1, eps=0)
    get_gradient: nn.Module = sobel.Sobel().to(device)

    for e in range(1, args.train_epochs + 1):
        print(f'Training epoch {e}...')

        if args.use_scheduler:
            lr_scheduler.step()

        # Training portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(True)
        with tqdm(total=args.train_batch_size * len(train_ds)) as progress_bar:
            model.train()
            for i, (x_batch_orig,
                    y_batch) in enumerate(train_ds.as_numpy_iterator()):
                x_batch, y_batch = model_utils.preprocess_training_example(
                    x_batch_orig, y_batch)
                y_blurred = model_utils.blur_depth_map(y_batch)

                ones = torch.ones(y_batch.shape,
                                  dtype=torch.float32,
                                  device=device)

                # Forward pass on model
                optimizer.zero_grad()
                y_pred = model(x_batch)

                depth_grad = get_gradient(y_blurred)
                output_grad = get_gradient(y_pred)
                depth_grad_dx = depth_grad[:, 0, :, :].contiguous().view_as(
                    y_blurred)
                depth_grad_dy = depth_grad[:, 1, :, :].contiguous().view_as(
                    y_batch)
                output_grad_dx = output_grad[:, 0, :, :].contiguous().view_as(
                    y_blurred)
                output_grad_dy = output_grad[:, 1, :, :].contiguous().view_as(
                    y_batch)

                depth_normal = torch.cat(
                    (-depth_grad_dx, -depth_grad_dy, ones), 1)
                output_normal = torch.cat(
                    (-output_grad_dx, -output_grad_dy, ones), 1)

                loss_depth = torch.log(torch.abs(y_pred - y_batch) +
                                       0.5).mean()
                loss_dx = torch.log(
                    torch.abs(output_grad_dx - depth_grad_dx) + 0.5).mean()
                loss_dy = torch.log(
                    torch.abs(output_grad_dy - depth_grad_dy) + 0.5).mean()
                loss_normal = torch.abs(
                    1 - cos(output_normal, depth_normal)).mean()

                loss = loss_depth + loss_normal + (loss_dx + loss_dy)

                # Backward pass and optimization
                loss.backward()
                optimizer.step()

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(loss=loss.item())
                writer.add_scalar("train/Loss", loss,
                                  ((e - 1) * len(train_ds) + i) *
                                  args.train_batch_size)

                # Periodically save a diagram
                if (i + 1) % args.picture_frequency == 0:
                    model_utils.make_diagram(
                        np.transpose(x_batch_orig, (0, 3, 1, 2)),
                        x_batch.cpu().numpy(),
                        y_batch.cpu().numpy(),
                        y_pred.cpu().detach().numpy(),
                        f'{args.save_path}/{args.experiment}/diagram_{e}_{i+1}.png',
                    )

                del x_batch
                del y_batch
                del y_blurred
                del y_pred
                del loss

        # Validation portion
        torch.cuda.empty_cache()
        torch.set_grad_enabled(False)

        with tqdm(total=args.dev_batch_size * len(dev_ds)) as progress_bar:
            model.eval()
            val_loss = 0.0
            num_batches_processed = 0
            total_pixels = 0
            total_examples = 0
            squared_error = 0
            rel_error = 0
            log_error = 0
            threshold1 = 0  # 1.25
            threshold2 = 0  # 1.25^2
            threshold3 = 0  # corresponds to 1.25^3

            for i, (x_batch, y_batch) in enumerate(dev_ds.as_numpy_iterator()):
                x_batch, y_batch = model_utils.preprocess_test_example(
                    x_batch, y_batch)
                # Forward pass on model in validation environment
                y_pred = model(x_batch)

                # TODO: Process y_pred in whatever way inference requires.
                loss = val_loss_fn(y_pred, y_batch)
                val_loss += loss.item()
                num_batches_processed += 1

                nanmask = getNanMask(y_batch)
                total_pixels = torch.sum(~nanmask)
                total_examples += x_batch.shape[0]

                # RMS, REL, LOG10, threshold calculation
                squared_error += (
                    torch.sum(torch.pow(y_pred - y_batch, 2)).item() /
                    total_pixels)**0.5
                rel_error += torch.sum(
                    removeNans(torch.abs(y_pred - y_batch) /
                               y_batch)).item() / total_pixels
                log_error += torch.sum(
                    torch.abs(
                        removeNans(torch.log10(y_pred)) - removeNans(
                            torch.log10(y_batch)))).item() / total_pixels
                threshold1 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch /
                              y_pred) < 1.25).item() / total_pixels
                threshold2 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch /
                              y_pred) < 1.25**2).item() / total_pixels
                threshold3 += torch.sum(
                    torch.max(y_pred / y_batch, y_batch /
                              y_pred) < 1.25**3).item() / total_pixels

                progress_bar.update(len(x_batch))
                progress_bar.set_postfix(val_loss=val_loss /
                                         num_batches_processed)
                writer.add_scalar("Val/Loss", loss,
                                  ((e - 1) * len(dev_ds) + i) *
                                  args.dev_batch_size)

                del x_batch
                del y_batch
                del y_pred
                del loss

            writer.add_scalar("Val/RMS", squared_error / total_examples, e)
            writer.add_scalar("Val/REL", rel_error / total_examples, e)
            writer.add_scalar("Val/LOG10", log_error / total_examples, e)
            writer.add_scalar("Val/delta1", threshold1 / total_examples, e)
            writer.add_scalar("Val/delta2", threshold2 / total_examples, e)
            writer.add_scalar("Val/delta3", threshold3 / total_examples, e)

            # Save model if it's the best one yet.
            if val_loss / num_batches_processed < best_val_loss:
                best_val_loss = val_loss / num_batches_processed
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_best_val.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model saved!')
                print(f'Best validation loss yet: {best_val_loss}')
            # Save model on checkpoints.
            if e % args.checkpoint_freq == 0:
                filename = f'{args.save_path}/{args.experiment}/{model.__class__.__name__}_epoch_{e}.checkpoint'
                model_utils.save_model(model, filename)
                print(f'Model checkpoint reached!')
                saved_checkpoints.append(filename)
                # Delete checkpoints if there are too many
                while len(saved_checkpoints) > args.num_checkpoints:
                    os.remove(saved_checkpoints.pop(0))

    return model
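The validation metrics above rely on `getNanMask` and `removeNans`, which are not shown; minimal sketches, assuming they respectively flag NaN ground-truth pixels and zero out NaN entries so the reductions stay finite:

import torch


def getNanMask(tensor: torch.Tensor) -> torch.Tensor:
    """Boolean mask that is True wherever the tensor is NaN."""
    return torch.isnan(tensor)


def removeNans(tensor: torch.Tensor) -> torch.Tensor:
    """Replace NaNs with zeros so sums over valid pixels stay finite."""
    return torch.where(torch.isnan(tensor), torch.zeros_like(tensor), tensor)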