Code example #1
def validate(val_loader, model, criterion, args, device):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(val_loader),
                             batch_time,
                             losses,
                             top1,
                             top5,
                             prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.print(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1,
                                                                    top5=top5))

    return top1.avg
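
This and the following code examples assume `time`, `torch`, and `functools.reduce` are already imported, and rely on `AverageMeter`, `ProgressMeter`, and `accuracy` helpers in the style of the official PyTorch ImageNet reference script. A minimal sketch of that assumed interface (not copied from the repository these excerpts come from) looks like this:

import torch


class AverageMeter(object):
    """Tracks the current value, sum, count, and running average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Prints one progress line built from a batch counter and a set of meters."""

    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    @staticmethod
    def _get_batch_fmtstr(num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def accuracy(output, target, topk=(1,)):
    """Computes top-k accuracy (in percent) for the given logits and labels."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res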
Code example #2
def train(train_loader,
          model,
          criterion,
          optimizer,
          epoch,
          args,
          device,
          ml_logger,
          val_loader,
          mq=None):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             batch_time,
                             data_time,
                             losses,
                             top1,
                             top5,
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    best_acc1 = -1
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.print(i)
            ml_logger.log_metric('Train Acc1',
                                 top1.avg,
                                 step='auto',
                                 log_to_tfboard=False)
            ml_logger.log_metric('Train Loss',
                                 losses.avg,
                                 step='auto',
                                 log_to_tfboard=False)
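
For context, a typical driver loop around the two functions above might look like the following. This is a hypothetical sketch: the model choice, the hyperparameters, and the `args`, `train_loader`, `val_loader`, and `ml_logger` objects are placeholders, not taken from the original code.

import torch
import torch.nn as nn
import torchvision.models as models

# Hypothetical setup; data loaders, args, and ml_logger are assumed to exist.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = models.resnet18().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
                            momentum=0.9, weight_decay=1e-4)

best_acc1 = 0
for epoch in range(90):
    # one pass over the training set; logs Acc@1/Loss every args.print_freq steps
    train(train_loader, model, criterion, optimizer, epoch,
          args, device, ml_logger, val_loader)
    # evaluate and keep track of the best top-1 accuracy
    acc1 = validate(val_loader, model, criterion, args, device)
    best_acc1 = max(acc1, best_acc1)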
Code example #3
def train(train_loader,
          model,
          criterion,
          optimizer,
          epoch,
          args,
          device,
          ml_logger,
          val_loader,
          mq=None,
          weight_to_hook=None,
          w_k_scale=0):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    w_k_losses = AverageMeter('W_K_Loss', ':.4e')
    w_k_vals = AverageMeter('W_K_Val', ':6.2f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(len(train_loader),
                             batch_time,
                             data_time,
                             losses,
                             w_k_losses,
                             w_k_vals,
                             top1,
                             top5,
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    best_acc1 = -1
    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # attach a kurtosis-regularization helper to every hooked weight tensor
        hookF_weights = {}
        for name, w_tensor in weight_to_hook.items():
            hookF_weights[name] = KurtosisWeight(
                w_tensor,
                name,
                kurtosis_target=args.w_kurtosis_target,
                k_mode=args.kurtosis_mode)

        # compute output
        output = model(images)

        # kurtosis regularization term computed over the hooked weights
        w_kurtosis_regularization = 0
        if args.w_kurtosis:
            w_temp_values = []
            w_kurtosis_loss = 0
            for w_kurt_inst in hookF_weights.values():
                w_kurt_inst.fn_regularization()
                w_temp_values.append(w_kurt_inst.kurtosis_loss)
            if args.kurtosis_mode == 'sum':
                w_kurtosis_loss = reduce((lambda a, b: a + b), w_temp_values)
            elif args.kurtosis_mode == 'avg':
                # sum, then divide by the hard-coded number of hooked
                # weight tensors for the given architecture
                w_kurtosis_loss = reduce((lambda a, b: a + b), w_temp_values)
                if args.arch == 'resnet18':
                    w_kurtosis_loss = w_kurtosis_loss / 19
                elif args.arch == 'mobilenet_v2':
                    w_kurtosis_loss = w_kurtosis_loss / 51
                elif args.arch == 'resnet50':
                    w_kurtosis_loss = w_kurtosis_loss / 52
            elif args.kurtosis_mode == 'max':
                w_kurtosis_loss = reduce((lambda a, b: max(a, b)),
                                         w_temp_values)
            w_kurtosis_regularization = (
                10**w_k_scale) * args.w_lambda_kurtosis * w_kurtosis_loss

        orig_loss = criterion(output, target)
        loss = orig_loss + w_kurtosis_regularization

        # recompute the raw (unscaled) kurtosis values for logging
        if args.w_kurtosis:
            w_temp_values = []
            for w_kurt_inst in hookF_weights.values():
                w_kurt_inst.fn_regularization()
                w_temp_values.append(w_kurt_inst.kurtosis)
            w_kurtosis_val = reduce((lambda a, b: a + b), w_temp_values)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        w_k_losses.update(w_kurtosis_regularization.item(), images.size(0))
        w_k_vals.update(w_kurtosis_val.item(), images.size(0))
        top1.update(acc1.item(), images.size(0))
        top5.update(acc5.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.print(i)
            ml_logger.log_metric('Train Acc1',
                                 top1.avg,
                                 step='auto',
                                 log_to_tfboard=False)
            ml_logger.log_metric('Train Loss',
                                 losses.avg,
                                 step='auto',
                                 log_to_tfboard=False)
            ml_logger.log_metric('Train weight kurtosis Loss',
                                 w_k_losses.avg,
                                 step='auto',
                                 log_to_tfboard=False)
            ml_logger.log_metric('Train weight kurtosis Val',
                                 w_k_vals.avg,
                                 step='auto',
                                 log_to_tfboard=False)

        # drop references to the kurtosis helpers before the next batch
        hookF_weights.clear()
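
Code example #3 relies on a `KurtosisWeight` helper whose implementation is not shown. A minimal sketch consistent with the interface used above (`fn_regularization()` populating `.kurtosis` and `.kurtosis_loss`) might look like the following; the exact loss formula and the default `kurtosis_target` are assumptions for illustration, not taken from the excerpted code.

import torch


class KurtosisWeight:
    """Hypothetical sketch: kurtosis regularization for one weight tensor."""

    def __init__(self, weight_tensor, name, kurtosis_target=1.8, k_mode='avg'):
        self.weight_tensor = weight_tensor
        self.name = name
        self.kurtosis_target = kurtosis_target  # assumed default, for illustration
        self.k_mode = k_mode  # stored only; reduction is handled by the caller
        self.kurtosis = None
        self.kurtosis_loss = None

    def fn_regularization(self):
        w = self.weight_tensor.view(-1)
        mean = w.mean()
        std = w.std()
        # sample kurtosis: E[((w - mean) / std)^4]
        self.kurtosis = torch.mean(((w - mean) / (std + 1e-8)) ** 4)
        # assumed penalty: squared distance from the target kurtosis
        self.kurtosis_loss = (self.kurtosis - self.kurtosis_target) ** 2

Under this sketch, the regularization term added to the task loss in code example #3 pushes each hooked weight tensor's kurtosis toward `kurtosis_target`, while the raw `kurtosis` values are summed separately purely for logging.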