Example #1
File: greed.py Project: ProQHA/proqha
def test(model,
         loader_test,
         data_length,
         device,
         criterion,
         batch_size,
         print_logger,
         step,
         use_top5=False,
         verbose=False):

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    t1 = time.time()
    with torch.no_grad():
        # switch to evaluate mode
        model.eval()
        end = time.time()

        for i, data in enumerate(loader_test):
            inputs = data[0]["data"].to(device)
            targets = data[0]["label"].squeeze().long().to(device)

            # compute output
            output = model(inputs)
            loss = criterion(output, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output, targets, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1[0], batch_size)
            top5.update(prec5[0], batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    t2 = time.time()

    print_logger.info('Test Step [{0}]: '
                      'Loss {loss.avg:.4f} '
                      'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f} '
                      'Time {time}'.format(step,
                                           loss=losses,
                                           top1=top1,
                                           top5=top5,
                                           time=t2 - t1))

    loader_test.reset()
    return top1.avg
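All of the examples on this page lean on two helpers that none of the excerpts define: an AverageMeter that tracks a metric's latest value, running sum, count, and average, and an accuracy function that returns top-k precision as a list of one-element tensors (hence the prec1[0] indexing throughout). Below is a minimal sketch of both in the style of the common PyTorch ImageNet recipe; the exact signatures in each project may differ.

import torch

class AverageMeter(object):
    """Tracks the latest value, running sum, count, and average of a metric."""
    def __init__(self, name='', fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes precision@k (as a percentage) for each k in topk."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # indices of the top-k classes
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res  # one-element tensors, matching the prec1[0] access pattern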
Example #2
def test(args, loader_test, model_s):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    cross_entropy = nn.CrossEntropyLoss()

    # switch to eval mode
    model_s.eval()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader_test, 1):
            
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model_s(inputs).to(device)
            loss = cross_entropy(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))
        
        print_logger.info('Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
                          .format(top1=top1, top5=top5))

    return top1.avg, top5.avg
Example #3

def test(args, loader_test, model, criterion, writer_test, epoch=0):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model.eval()

    print("=> Evaluating...")
    logging.info('=> Evaluating...')
    with torch.no_grad():
        for i, data in enumerate(loader_test):
            inputs = torch.cat([data[j]["data"] for j in range(num_gpu)], dim=0)
            targets = torch.cat([data[j]["label"] for j in range(num_gpu)], dim=0).squeeze().long()

            inputs = inputs.to(args.gpus[0])
            targets = targets.to(args.gpus[0])

            logits = model(inputs)
            loss = criterion(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
        print(f'* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}')
        logging.info('Top1: %e Top5: %e ', top1.avg, top5.avg)

    if not args.test_only:
        writer_test.add_scalar('test_top1', top1.avg, epoch)

    return top1.avg, top5.avg
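Examples #1 and #3 (and several later ones) iterate an NVIDIA DALI-style loader rather than a plain torch DataLoader: each step yields a list with one dict per GPU pipeline, each dict holding "data" and "label" batches, and the iterator must be reset manually between epochs. A sketch of that access pattern, assuming a DALIClassificationIterator and a pipelines list that the excerpts never show:

import torch
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

# `pipelines` is assumed: one DALI pipeline per GPU, already built.
loader_test = DALIClassificationIterator(pipelines, reader_name="Reader")
for data in loader_test:
    # merge the per-pipeline shards into a single batch
    inputs = torch.cat([data[j]["data"] for j in range(len(pipelines))], dim=0)
    targets = torch.cat([data[j]["label"] for j in range(len(pipelines))],
                        dim=0).squeeze().long()
loader_test.reset()  # unlike a torch DataLoader, DALI needs an explicit reset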
Example #4
def train(model, optimizer, trainLoader, args, epoch):

    model.train()
    losses = utils.AverageMeter()
    accuracy = utils.AverageMeter()
    print_freq = len(trainLoader.dataset) // args.train_batch_size // 10
    start_time = time.time()
    for batch, (inputs, targets) in enumerate(trainLoader):

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_func(output, targets)
        loss.backward()
        losses.update(loss.item(), inputs.size(0))
        optimizer.step()

        prec1 = utils.accuracy(output, targets)
        accuracy.update(prec1[0], inputs.size(0))

        if batch % print_freq == 0 and batch != 0:
            current_time = time.time()
            cost_time = current_time - start_time
            logger.info('Epoch[{}] ({}/{}):\t'
                        'Loss {:.4f}\t'
                        'Accuracy {:.2f}%\t\t'
                        'Time {:.2f}s'.format(epoch,
                                              batch * args.train_batch_size,
                                              len(trainLoader.dataset),
                                              float(losses.avg),
                                              float(accuracy.avg), cost_time))
            start_time = current_time
Example #5
def test(args, loader_test, model, criterion, writer_test, epoch=0):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model.eval()
    num_iterations = len(loader_test)

    print("=> Evaluating...")
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader_test, 1):

            inputs = inputs.to(args.gpus[0])
            targets = targets.to(args.gpus[0])

            logits = model(inputs)
            loss = criterion(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

        print(f'* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}')

    if not args.test_only:
        writer_test.add_scalar('test_top1', top1.avg, epoch)

    return top1.avg, top5.avg
Example #6
def test(args, loader_test, model, criterion, writer_test, epoch):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model.eval()
    num_iterations = len(loader_test)

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader_test, 1):
            num_iters = num_iterations * epoch + i

            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss = criterion(logits, targets)

            writer_test.add_scalar('Test_loss (fine-tuned)', loss.item(),
                                   num_iters)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

            writer_test.add_scalar('Prec@1', top1.avg, num_iters)
            writer_test.add_scalar('Prec@5', top5.avg, num_iters)

    print_logger.info(f'* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}')
    return top1.avg, top5.avg
Example #7
def train(args, loader_train, model, criterion, optimizer, writer_train,
          epoch):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model.train()
    num_iterations = len(loader_train)

    for i, (inputs, targets) in enumerate(loader_train, 1):
        num_iters = num_iterations * epoch + i

        inputs = inputs.to(args.gpus[0])
        targets = targets.to(args.gpus[0])

        logits = model(inputs)
        loss = criterion(logits, targets)

        writer_train.add_scalar('Train_loss (fine-tuned)', loss.item(),
                                num_iters)

        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))

        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        writer_train.add_scalar('Prec@1', top1.avg, num_iters)
        writer_train.add_scalar('Prec@5', top5.avg, num_iters)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Example #8
def train(args, loader_train, model, criterion, optimizer, writer_train, epoch):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model.train()
    num_iterations = len(loader_train)

    for i, (inputs, targets) in enumerate(loader_train, 1):

        inputs = inputs.to(args.gpus[0])
        targets = targets.to(args.gpus[0])

        logits = model(inputs)
        loss = criterion(logits, targets)

        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))

        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Example #9
def test(model, testLoader, topk=(1,)):
    model.eval()

    losses = utils.AverageMeter()
    accuracy = utils.AverageMeter()
    top5_accuracy = utils.AverageMeter()

    start_time = time.time()
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(testLoader):
            inputs = batch_data[0]['data'].to(device)
            targets = batch_data[0]['label'].squeeze().long().to(device)
            outputs = model(inputs)
            loss = loss_func(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            predicted = utils.accuracy(outputs, targets, topk=topk)
            accuracy.update(predicted[0], inputs.size(0))
            top5_accuracy.update(predicted[1], inputs.size(0))

        current_time = time.time()
        logger.info(
            'Test Loss {:.4f}\tTop1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'
                .format(float(losses.avg), float(accuracy.avg), float(top5_accuracy.avg), (current_time - start_time))
        )

    return top5_accuracy.avg, accuracy.avg
Example #10
def test(model, testLoader):
    global best_acc
    model.eval()

    losses = utils.AverageMeter()
    accuracy = utils.AverageMeter()

    start_time = time.time()

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testLoader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = loss_func(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            predicted = utils.accuracy(outputs, targets)
            accuracy.update(predicted[0], inputs.size(0))

        current_time = time.time()
        logger.info(
            'Test Loss {:.4f}\tAccuracy {:.2f}%\t\tTime {:.2f}s\n'.format(
                float(losses.avg), float(accuracy.avg),
                (current_time - start_time)))
    return accuracy.avg
Example #11
File: main.py Project: yunhengzi/GAL
def test(args, loader_test, model_s):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    cross_entropy = nn.CrossEntropyLoss()

    # switch to eval mode
    model_s.eval()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader_test, 1):

            inputs = inputs.to(args.gpus[0])
            targets = targets.to(args.gpus[0])

            logits = model_s(inputs).to(args.gpus[0])
            loss = cross_entropy(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

        print('* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))

        mask = []
        for name, weight in model_s.named_parameters():
            if 'mask' in name:
                mask.append(weight.item())

        print("* Pruned {} / {}".format(sum(m == 0 for m in mask), len(mask)))

    return top1.avg, top5.avg
Example #12
def test(model, testLoader, topk=(1, )):
    model.eval()

    losses = utils.AverageMeter()
    accuracy = utils.AverageMeter()
    top5_accuracy = utils.AverageMeter()

    start_time = time.time()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testLoader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = loss_func(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            predicted = utils.accuracy(outputs, targets, topk=topk)
            accuracy.update(predicted[0], inputs.size(0))
            if len(topk) == 2:
                top5_accuracy.update(predicted[1], inputs.size(0))

        current_time = time.time()
        if len(topk) == 1:
            logger.info(
                'Test Loss {:.4f}\tAccuracy {:.2f}%\t\tTime {:.2f}s\n'.format(
                    float(losses.avg), float(accuracy.avg),
                    (current_time - start_time)))
        else:
            logger.info(
                'Test Loss {:.4f}\tTop1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'.
                format(float(losses.avg), float(accuracy.avg),
                       float(top5_accuracy.avg), (current_time - start_time)))
    if len(topk) == 1:
        return accuracy.avg
    else:
        return top5_accuracy.avg
Example #13
    def test(model, testLoader, topk=(1, )):
        model.eval()

        losses = utils.AverageMeter('Loss', ':.4e')
        top1_accuracy = utils.AverageMeter('Acc@1', ':6.2f')
        top5_accuracy = utils.AverageMeter('Acc@5', ':6.2f')

        start_time = time.time()

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testLoader):
                inputs, targets = inputs.to(device), targets.to(device)

                # compute output
                outputs = model(inputs)
                loss = loss_func(outputs, targets)

                # measure accuracy and record loss
                losses.update(loss.item(), inputs.size(0))
                pred = utils.accuracy(outputs, targets, topk=topk)
                top1_accuracy.update(pred[0], inputs.size(0))
                top5_accuracy.update(pred[1], inputs.size(0))

            # measure elapsed time
            current_time = time.time()
            print(
                f'Test Loss: {float(losses.avg):.6f}\t Top1: {float(top1_accuracy.avg):.6f}%\t'
                f'Top5: {float(top5_accuracy.avg):.6f}%\t Time: {float(current_time - start_time):.2f}s'
            )

        return float(top1_accuracy.avg), float(top5_accuracy.avg)
Example #14
    def validate(self, epoch, val_loader, model, criterion):
        batch_time = utils.AverageMeter('Time', ':6.3f')
        losses = utils.AverageMeter('Loss', ':.4e')
        top1 = utils.AverageMeter('Acc@1', ':6.2f')
        top5 = utils.AverageMeter('Acc@5', ':6.2f')

        # switch to evaluation mode
        model.eval()
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(val_loader):
                images = images.cuda()
                target = target.cuda()

                # compute output
                logits = model(images)
                loss = criterion(logits, target)

                # measure accuracy and record loss
                pred1, pred5 = utils.accuracy(logits, target, topk=(1, 5))
                n = images.size(0)
                losses.update(loss.item(), n)
                top1.update(pred1[0], n)
                top5.update(pred5[0], n)

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

            self.logger.info(
                ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
                    top1=top1, top5=top5))

        return losses.avg, top1.avg, top5.avg
Example #15
def train_class(epoch, train_loader, model, criterion, optimizer):
    batch_time = utils.AverageMeter('Time', ':6.3f')
    data_time = utils.AverageMeter('Data', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    model.train()
    end = time.time()

    num_iter = len(train_loader)

    print_freq = num_iter // 10
    i = 0 
    
    for batch_idx, (images, targets) in enumerate(train_loader):
        if args.debug:
            if i > 5:
                break
            i += 1
        images = images.to(device)
        targets = targets.to(device)
        data_time.update(time.time() - end)

        # compute output
        logits, mask = model(images, targets)
        loss = criterion(logits, targets)
        for m in mask:
            loss += float(args.sparse_lambda) * torch.sum(m, 0).norm(2)

        # measure accuracy and record loss
        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        n = images.size(0)
        losses.update(loss.item(), n)  # accumulated loss
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if batch_idx % print_freq == 0 and batch_idx != 0:
            logger.info(
                'Epoch[{0}]({1}/{2}): '
                'Loss {loss.avg:.4f} '
                'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                    epoch, batch_idx, num_iter, loss=losses,
                    top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
Example #16
    def train(self, epoch, train_loader, model, criterion, optimizer,
              scheduler):
        batch_time = utils.AverageMeter('Time', ':6.3f')
        data_time = utils.AverageMeter('Data', ':6.3f')
        losses = utils.AverageMeter('Loss', ':.4e')
        top1 = utils.AverageMeter('Acc@1', ':6.2f')
        top5 = utils.AverageMeter('Acc@5', ':6.2f')

        model.train()
        end = time.time()

        for param_group in optimizer.param_groups:
            cur_lr = param_group['lr']
        self.logger.info('learning_rate: ' + str(cur_lr))

        num_iter = len(train_loader)
        for i, (images, target) in enumerate(train_loader):
            data_time.update(time.time() - end)
            images = images.cuda()
            target = target.cuda()

            # compute output
            logits = model(images)
            loss = criterion(logits, target)

            # measure accuracy and record loss
            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)  #accumulated loss
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.args.print_freq == 0:
                self.logger.info(
                    'Epoch[{0}]({1}/{2}): '
                    'Loss {loss.avg:.4f} '
                    'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                        epoch, i, num_iter, loss=losses, top1=top1, top5=top5))

        scheduler.step()

        return losses.avg, top1.avg, top5.avg
Example #17
def train(model, optimizer, trainLoader, args, epoch, topk=(1,)):

    model.train()
    losses = utils.AverageMeter()
    accuracy = utils.AverageMeter()
    top5_accuracy = utils.AverageMeter()
    print_freq = trainLoader._size // args.train_batch_size // 10
    start_time = time.time()
    for batch, batch_data in enumerate(trainLoader):

        inputs = batch_data[0]['data'].to(device)

        targets = batch_data[0]['label'].squeeze().long().to(device)

        train_loader_len = int(math.ceil(trainLoader._size / args.train_batch_size))

        adjust_learning_rate(optimizer, epoch, batch, train_loader_len, args)


        output = model(inputs)
        loss = loss_func(output, targets)
        optimizer.zero_grad()
        loss.backward()
        losses.update(loss.item(), inputs.size(0))
        optimizer.step()

        prec1 = utils.accuracy(output, targets, topk=topk)
        accuracy.update(prec1[0], inputs.size(0))
        if len(topk) == 2:
            top5_accuracy.update(prec1[1], inputs.size(0))


        if batch % print_freq == 0 and batch != 0:
            current_time = time.time()
            cost_time = current_time - start_time
            logger.info(
                'Epoch[{}] ({}/{}):\t'
                'Loss {:.4f}\t'
                'Top1 {:.2f}%\t'
                'Top5 {:.2f}%\t'
                'Time {:.2f}s'.format(
                    epoch, batch * args.train_batch_size, trainLoader._size,
                    float(losses.avg), float(accuracy.avg), float(top5_accuracy.avg), cost_time
                )
            )
            start_time = current_time
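Examples #17, #29, and #30 call an adjust_learning_rate helper that never appears in the excerpts (Example #30 uses a four-argument variant). One plausible implementation matching the five-argument call above, assuming a cosine decay over all iterations; args.lr and args.num_epochs are assumed fields, and the projects' real schedules may well be step-based instead:

import math

def adjust_learning_rate(optimizer, epoch, step, len_epoch, args):
    # fraction of total training completed, in [0, 1]
    progress = (epoch * len_epoch + step) / (args.num_epochs * len_epoch)
    lr = 0.5 * args.lr * (1 + math.cos(math.pi * progress))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr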
Example #18
def test(args, loader_test, model_s):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    cross_entropy = nn.CrossEntropyLoss()

    # switch to eval mode
    model_s.eval()

    with torch.no_grad():
        for i, data in enumerate(loader_test):
            inputs = torch.cat([data[j]["data"] for j in range(num_gpu)],
                               dim=0)
            targets = torch.cat([data[j]["label"] for j in range(num_gpu)],
                                dim=0).squeeze().long()

            targets = targets.cuda(non_blocking=True)
            inputs = inputs.cuda()

            logits = model_s(inputs).to(args.gpus[0])
            loss = cross_entropy(logits, targets)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

        print('* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))
        logging.info('Top1: %e Top5: %e ', top1.avg, top5.avg)

        mask = []
        for name, weight in model_s.named_parameters():
            if 'mask' in name:
                for i in range(len(weight)):
                    mask.append(weight[i].item())

        print("* Pruned {} / {}".format(sum(m == 0 for m in mask), len(mask)))
        logging.info('Pruned: %e  Total: %e ', sum(m == 0 for m in mask),
                     len(mask))

    return top1.avg, top5.avg
Example #19
File: cifar.py Project: zyxxmu/White-Box
def train_class(model, optimizer, trainLoader, epoch, topk=(1, )):

    model.train()
    losses = utils.AverageMeter(':.4e')
    accuracy = utils.AverageMeter(':6.3f')
    top5_accuracy = utils.AverageMeter(':6.3f')
    print_freq = len(trainLoader.dataset) // args.train_batch_size // 10
    start_time = time.time()
    for batch, (inputs, targets) in enumerate(trainLoader):

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        output, mask = model(inputs, targets)
        loss = loss_func(output, targets)

        for m in mask:
            loss += float(args.sparse_lambda) * torch.sum(m, 0).norm(2)

        loss.backward()
        losses.update(loss.item(), inputs.size(0))
        optimizer.step()

        prec1 = utils.accuracy(output, targets, topk=topk)
        accuracy.update(prec1[0], inputs.size(0))
        if len(topk) == 2:
            top5_accuracy.update(prec1[1], inputs.size(0))

        if batch % print_freq == 0 and batch != 0:
            current_time = time.time()
            cost_time = current_time - start_time
            if len(topk) == 1:
                logger.info('Epoch[{}] ({}/{}):\t'
                            'Loss {:.4f}\t'
                            'Accuracy {:.2f}%\t\t'
                            'Time {:.2f}s'.format(
                                epoch, batch * args.train_batch_size,
                                len(trainLoader.dataset), float(losses.avg),
                                float(accuracy.avg), cost_time))
            else:
                logger.info('Epoch[{}] ({}/{}):\t'
                            'Loss {:.4f}\t'
                            'Top1 {:.2f}%\t'
                            'Top5 {:.2f}%\t'
                            'Time {:.2f}s'.format(
                                epoch, batch * args.train_batch_size,
                                len(trainLoader.dataset), float(losses.avg),
                                float(accuracy.avg), float(top5_accuracy.avg),
                                cost_time))
            start_time = current_time
Example #20
def train(args, loader_train, model, criterion, optimizer, writer_train, epoch, model_kd = None):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model.train()

    for i, data in enumerate(loader_train):
        inputs = torch.cat([data[j]["data"] for j in range(num_gpu)], dim=0)
        targets = torch.cat([data[j]["label"] for j in range(num_gpu)], dim=0).squeeze().long()


        inputs = inputs.to(args.gpus[0])
        targets = targets.to(args.gpus[0])
        logits = model(inputs)
        loss = criterion(logits, targets)

        prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        optimizer.zero_grad()
        loss.backward(retain_graph=True)

        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))
        if i % 500 == 0:
            print(f'* Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}')
            logging.info('Top1: %e Top5: %e ', top1.avg, top5.avg)

        # distillation against the teacher model, if one was provided
        if model_kd is not None:
            inputs = inputs.to(args.gpus[1])
            features_kd = model_kd(inputs)
            alpha = 0.99
            Temperature = 30
            logits = logits.to(args.gpus[1])
            KD_loss = nn.KLDivLoss()(F.log_softmax(logits / Temperature, dim=1),
                                     F.softmax(features_kd / Temperature, dim=1)) * (
                                  alpha * Temperature * Temperature) #+ F.cross_entropy(logits, targets) * (1 - alpha)
            KD_loss.backward()
        optimizer.step()
Example #21
def validate(val_loader, model, criterion, args):
    batch_time = utils.AverageMeter('Time', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    num_iter = len(val_loader)

    model.eval()
    with torch.no_grad():
        end = time.time()
        i = 0
        for batch_idx, (images, targets) in enumerate(val_loader):
            if args.debug:
                if i > 5:
                    break
                i += 1
            images = images.cuda()
            targets = targets.cuda()

            # compute output
            logits = model(images)
            loss = criterion(logits, targets)

            # measure accuracy and record loss
            pred1, pred5 = utils.accuracy(logits, targets, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)
            top1.update(pred1[0], n)
            top5.update(pred5[0], n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

        logger.info(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
                    .format(top1=top1, top5=top5))

    return losses.avg, top1.avg, top5.avg
Example #22
    def test(model, testLoader):
        model.eval()
        losses = utils.AverageMeter('Loss', ':.4e')
        accuracy = utils.AverageMeter('Acc@1', ':6.2f')

        start_time = time.time()

        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testLoader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = loss_func(outputs, targets)

                losses.update(loss.item(), inputs.size(0))
                pred = utils.accuracy(outputs, targets)
                accuracy.update(pred[0], inputs.size(0))

            current_time = time.time()
            print(
                f'Test Loss: {float(losses.avg):.4f}\t Acc: {float(accuracy.avg):.2f}%\t\t Time: {(current_time - start_time):.2f}s'
            )
        return accuracy.avg
Example #23
def test(args, loader_test, model_s, epoch):
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    cross_entropy = nn.CrossEntropyLoss()

    # switch to eval mode
    model_s.eval()

    num_iterations = len(loader_test)

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(loader_test, 1):
            num_iters = num_iterations * epoch + i

            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model_s(inputs).to(device)
            loss = cross_entropy(logits, targets)

            writer_test.add_scalar('Test_loss', loss.item(), num_iters)

            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1[0], inputs.size(0))
            top5.update(prec5[0], inputs.size(0))

            writer_test.add_scalar('Prec@1', top1.avg, num_iters)
            writer_test.add_scalar('Prec@5', top5.avg, num_iters)

    print_logger.info(
        'Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}\n'
        '===============================================\n'.format(top1=top1,
                                                                   top5=top5))

    return top1.avg, top5.avg
Example #24
def test(model, topk=(1, )):
    model.eval()

    losses = utils.AverageMeter()
    accuracy = utils.AverageMeter()
    top5_accuracy = utils.AverageMeter()

    start_time = time.time()
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(testLoader):
            if len(topk) == 2:
                inputs = batch_data[0]['data'].to(device)
                targets = batch_data[0]['label'].squeeze().long().to(device)
            else:
                inputs = batch_data[0]
                targets = batch_data[1]
                inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = loss_func(outputs, targets)

            losses.update(loss.item(), inputs.size(0))
            predicted = utils.accuracy(outputs, targets, topk=topk)
            accuracy.update(predicted[0], inputs.size(0))
            if len(topk) == 2:
                top5_accuracy.update(predicted[1], inputs.size(0))

        current_time = time.time()
        if len(topk) == 1:
            print(
                'Test Loss {:.4f}\tAccuracy {:.2f}%\t\tTime {:.2f}s\n'.format(
                    float(losses.avg), float(accuracy.avg),
                    (current_time - start_time)))
        else:
            print(
                'Test Loss {:.4f}\tTop1 {:.2f}%\tTop5 {:.2f}%\tTime {:.2f}s\n'.
                format(float(losses.avg), float(accuracy.avg),
                       float(top5_accuracy.avg), (current_time - start_time)))
Example #25
def bn_update(model, loader, cumulative=False):
    """
        BatchNorm buffers update (if any).
        Performs 1 epochs to estimate buffers average using train dataset.
        :param model: model being update
        :param loader: train dataset loader for buffers average estimation.
        :param cumulative: cumulative moving average or exponential moving average
        :return: approcimate train accuracy (util.AverageMeter)
    """
    if not check_bn(model):
        return

    print("approcimate process:")

    train_acc = utils.AverageMeter()
    model.train()
    model.apply(reset_bn)

    if cumulative:
        momenta = {}
        model.apply(lambda module: _get_momenta(module, momenta))
        for module in momenta.keys():
            module.momentum = None
    approcimate_num = int(len(loader) * args.approcimate_rate)
    with torch.no_grad():  # freeze all the parameters
        for i, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            prec1 = utils.accuracy(outputs, targets)
            train_acc.update(prec1[0], inputs.size(0))
            if i >= approcimate_num:
                break
    return train_acc
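bn_update depends on three small helpers that the excerpt omits. A minimal sketch in the style of the SWA reference code (an assumption; the project's own versions are not shown). Note that setting module.momentum to None, as the cumulative branch above does, switches PyTorch's BatchNorm to a cumulative moving average.

import torch

def check_bn(model):
    # True if the model contains at least one BatchNorm layer
    return any(isinstance(m, torch.nn.modules.batchnorm._BatchNorm)
               for m in model.modules())

def reset_bn(module):
    # clear running_mean / running_var so they can be re-estimated
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        module.reset_running_stats()

def _get_momenta(module, momenta):
    # record each BatchNorm module's momentum so it can be overridden or restored
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        momenta[module] = module.momentum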
Example #26
def test(args, loader_test, model, epoch=0):

    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    top1_t = utils.AverageMeter()
    top5_t = utils.AverageMeter()

    top1_s1 = utils.AverageMeter()
    top1_s2 = utils.AverageMeter()
    top1_s3 = utils.AverageMeter()
    top1_s4 = utils.AverageMeter()

    top5_s1 = utils.AverageMeter()
    top5_s2 = utils.AverageMeter()
    top5_s3 = utils.AverageMeter()
    top5_s4 = utils.AverageMeter()

    model.eval()
    num_iterations = len(loader_test)

    with torch.no_grad():
        print_logger.info("=> Evaluating...")

        for i, (inputs, targets) in enumerate(loader_test, 1):

            inputs = inputs.cuda()
            targets = targets.cuda()

            # compute output
            logits_s, logits_t = model(inputs)
            best_prec_s_1 = 0.
            for j in range(args.num_stu):
                prec1, prec5 = utils.accuracy(logits_s[j],
                                              targets,
                                              topk=(1, 5))
                eval('top1_s%d' % (j + 1)).update(prec1[0], inputs.size(0))
                eval('top5_s%d' % (j + 1)).update(prec5[0], inputs.size(0))
                if prec1 > best_prec_s_1:
                    best_prec_s_1 = prec1

                writer_test.add_scalar('test_stu_%d_top1' % (j + 1), prec1[0],
                                       num_iterations * epoch + i)

            prec1, prec5 = utils.accuracy(logits_t, targets, topk=(1, 5))
            writer_test.add_scalar('test_tea_top1', prec1[0],
                                   num_iterations * epoch + i)
            top1_t.update(prec1[0], inputs.size(0))
            top5_t.update(prec5[0], inputs.size(0))

        for j in range(args.num_stu):
            if eval('top1_s%d' % (j + 1)).avg > top1.avg:
                top1.avg = eval('top1_s%d' % (j + 1)).avg
                top5.avg = eval('top5_s%d' % (j + 1)).avg

        print_logger.info('Epoch[{0}]({1}/{2}): '
                          'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                              epoch, i, num_iterations, top1=top1, top5=top5))

        for i in range(args.num_stu):
            print_logger.info('top1_s%d: %.2f' %
                              (i + 1, eval('top1_s%d' % (i + 1)).avg))
        print_logger.info('top1_t: %.2f' % (top1_t.avg))

    if not args.test_only:
        writer_test.add_scalar('test_top1', top1.avg, epoch)

    return top1.avg, top5.avg
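The eval('top1_s%d' % (j + 1)) lookups above work but are fragile. An equivalent, plainer structure (a sketch, not from the source) keeps one meter per student branch in a list:

top1_s = [utils.AverageMeter() for _ in range(args.num_stu)]
top5_s = [utils.AverageMeter() for _ in range(args.num_stu)]

for j in range(args.num_stu):
    prec1, prec5 = utils.accuracy(logits_s[j], targets, topk=(1, 5))
    top1_s[j].update(prec1[0], inputs.size(0))
    top5_s[j].update(prec5[0], inputs.size(0))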
Example #27
def train(args, loader_train, model, criterion, optimizer, epoch):

    losses = utils.AverageMeter()
    top1_t = utils.AverageMeter()
    top5_t = utils.AverageMeter()
    top1_s = utils.AverageMeter()
    top5_s = utils.AverageMeter()

    model.train()

    # update learning rate
    for param_group in optimizer.param_groups:
        writer_train.add_scalar('learning_rate', param_group['lr'], epoch)

    num_iterations = len(loader_train)

    for i, (inputs, targets) in enumerate(loader_train, 1):

        inputs = inputs.cuda()
        targets = targets.cuda()

        # compute output
        logits_s, logits_t = model(inputs)

        loss = criterion(logits_t, targets)
        best_prec_s_1 = torch.tensor(0.).cuda()
        best_prec_s_5 = torch.tensor(0.).cuda()
        best_branch = 1
        for j in range(args.num_stu):
            loss += criterion(logits_s[j], targets)
            loss += args.t * args.t * utils.KL(logits_t / args.t,
                                               logits_s[j] / args.t)
            prec1, prec5 = utils.accuracy(logits_s[j], targets, topk=(1, 5))
            writer_train.add_scalar('train_stu_%d_top1' % (j + 1),
                                    prec1.item(), num_iterations * epoch + i)
            if prec1 > best_prec_s_1:
                best_prec_s_1 = prec1
                best_prec_s_5 = prec5
                best_branch = j + 1

        prec1 = best_prec_s_1
        prec5 = best_prec_s_5

        prec1_t, prec5_t = utils.accuracy(logits_t, targets, topk=(1, 5))
        top1_t.update(prec1_t.item(), inputs.size(0))
        top5_t.update(prec5_t.item(), inputs.size(0))

        losses.update(loss.item(), inputs.size(0))

        writer_train.add_scalar('train_top1', prec1.item(),
                                num_iterations * epoch + i)
        writer_train.add_scalar('train_loss', loss.item(),
                                num_iterations * epoch + i)

        top1_s.update(prec1.item(), inputs.size(0))
        top5_s.update(prec5.item(), inputs.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        if i % args.print_freq == 0:
            print_logger.info(
                'Epoch[{0}]({1}/{2}): '
                'Loss {loss.avg:.4f} '
                'TeacherPrec@1(1,5) {top1_t.avg:.2f}, {top5_t.avg:.2f} '
                'StuPrec@1(1,5) {top1_s.avg:.2f}, {top5_s.avg:.2f} '
                'Best branch: {best_branch: d}'.format(
                    epoch,
                    i,
                    num_iterations,
                    loss=losses,
                    top1_t=top1_t,
                    top5_t=top5_t,
                    top1_s=top1_s,
                    top5_s=top5_s,
                    best_branch=best_branch))

    return losses.avg, top1_s.avg
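Example #27 distills each student branch against the teacher via utils.KL on temperature-scaled logits, weighted by args.t * args.t as in standard knowledge distillation. One plausible definition consistent with the call KL(logits_t / args.t, logits_s[j] / args.t) (an assumption; the utility itself is not in the excerpt):

import torch.nn.functional as F

def KL(teacher_logits, student_logits):
    # KL(teacher || student) on softened distributions, averaged per sample
    return F.kl_div(F.log_softmax(student_logits, dim=1),
                    F.softmax(teacher_logits, dim=1),
                    reduction='batchmean')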
Example #28
def train(args, loader_train, models, optimizers, epoch, writer_train):

    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()

    model_t = models[0]
    model_s = models[1]
    model_d = models[2]
    #model_kd = models[3]

    bce_logits = nn.BCEWithLogitsLoss()

    optimizer_d = optimizers[0]
    optimizer_s = optimizers[1]
    optimizer_m = optimizers[2]

    # switch to train mode
    model_d.train()
    model_s.train()
    num_iterations = int(loader_train._size / batch_sizes)
    real_label = 1
    fake_label = 0
    exact_list = ["layer3"]
    num_pruned = -1
    t0 = time.time()
    for i, data in enumerate(loader_train):

        global iteration
        iteration = i

        if i % 60 == 1:
            t0 = time.time()

        if i % 400 == 1:
            num_mask = []
            for name, weight in model_s.named_parameters():
                if 'mask' in name:
                    for ii in range(len(weight)):
                        num_mask.append(weight[ii].item())
            num_pruned = sum(m == 0 for m in num_mask)
            if num_pruned > 1100:
                iteration = 1

        if i > 100 and top1.val < 30:
            iteration = 1
        gl.set_value('iteration', iteration)

        inputs = torch.cat([data[j]["data"] for j in range(num_gpu)], dim=0)
        targets = torch.cat([data[j]["label"] for j in range(num_gpu)],
                            dim=0).squeeze().long()

        targets = targets.cuda(non_blocking=True)
        inputs = inputs.cuda()

        features_t = model_t(inputs)
        features_s = model_s(inputs)

        ############################
        # (1) Update
        # D network
        ###########################
        for p in model_d.parameters():
            p.requires_grad = True

        optimizer_d.zero_grad()

        output_t = model_d(features_t.to(args.gpus[0]).detach())

        labels_real = torch.full_like(output_t,
                                      real_label,
                                      device=args.gpus[0])
        error_real = bce_logits(output_t, labels_real)

        output_s = model_d(features_s.to(args.gpus[0]).detach())

        labels_fake = torch.full_like(output_t,
                                      fake_label,
                                      device=args.gpus[0])
        error_fake = bce_logits(output_s, labels_fake)

        error_d = 0.1 * error_real + 0.1 * error_fake

        error_d.backward()


        optimizer_d.step()

        ############################
        # (2) Update student network
        ###########################


        for p in model_d.parameters():
            p.requires_grad = False

        optimizer_s.zero_grad()
        optimizer_m.zero_grad()

        alpha = 0.9 - epoch / args.num_epochs * 0.9
        Temperature = 10
        KD_loss = 10 * nn.KLDivLoss()(
            F.log_softmax(features_s / Temperature, dim=1),
            F.softmax(features_t / Temperature, dim=1)) * (
                alpha * Temperature * Temperature) + F.cross_entropy(
                    features_s, targets) * (1 - alpha)
        KD_loss.backward(retain_graph=True)

        # data_loss
        alpha = 0.9 - epoch / args.num_epochs * 0.9
        #one_hot = torch.zeros(targets.shape[0], 1000).cuda()
        #one_hot = one_hot.scatter_(1, targets.reshape(targets.shape[0],1), 1).cuda()
        error_data = args.miu * (
            alpha * F.mse_loss(features_t, features_s.to(args.gpus[0]))
        )  # + (1 - alpha) * F.mse_loss(one_hot, features_s.to(args.gpus[0])))
        error_data.backward(retain_graph=True)

        # fool discriminator
        output_s = model_d(features_s.to(args.gpus[0]))
        labels = torch.full_like(output_s, real_label, device=args.gpus[0])
        error_g = 0.1 * bce_logits(output_s, labels)
        error_g.backward(retain_graph=True)

        optimizer_s.step()


        # train mask
        error_sparse = 0
        decay = (epoch % args.lr_decay_step == 0 and i == 1)
        if i % (args.mask_step) == 0:
            mask = []
            for name, param in model_s.named_parameters():
                if 'mask' in name:
                    mask.append(param.view(-1))
            mask = torch.cat(mask)
            error_sparse = 0.00001 * args.sparse_lambda * F.l1_loss(
                mask,
                torch.zeros(mask.size()).to(args.gpus[0]),
                reduction='sum')
            error_sparse.backward()
            optimizer_m.step(decay)
        prec1, prec5 = utils.accuracy(features_s.to(args.gpus[0]),
                                      targets.to(args.gpus[0]),
                                      topk=(1, 5))
        top1.update(prec1[0], inputs.size(0))
        top5.update(prec5[0], inputs.size(0))

        if i % 60 == 0:
            t1 = time.time()
            print('=> G_Epoch[{0}]({1}/{2}):\n'
                  'Loss_s {loss_sparse:.4f} \t'
                  'Loss_data {loss_data:.4f}\t'
                  'Loss_d {loss_d:.4f} \n'
                  'Loss_g {loss_g:.4f} \t'
                  'Loss_kl {loss_kl:.4f} \n'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\n'
                  'time {time:.4f}\t'
                  'pruned {np}'.format(epoch,
                                       i,
                                       num_iterations,
                                       loss_sparse=error_sparse,
                                       loss_data=error_data,
                                       loss_d=error_d,
                                       loss_g=error_g,
                                       loss_kl=KD_loss,
                                       top1=top1,
                                       top5=top5,
                                       time=t1 - t0,
                                       np=num_pruned))
            logging.info(
                'TRAIN epoch: %03d step : %03d  Top1: %e Top5: %e error_g: %e error_data: %e error_d: %e Duration: %f Pruned: %d',
                epoch, i, top1.avg, top5.avg, error_g, error_data, error_d,
                t1 - t0, num_pruned)
Example #29
def calculationFitness(honey, args):
    global best_honey
    global best_honey_state


    if args.arch == 'vgg':
        model = import_module(f'model.{args.arch}').BeeVGG(honeysource=honey, num_classes=1000).to(device)
        load_vgg_honey_model(model, args.random_rule)
    elif args.arch == 'resnet':
        model = import_module(f'model.{args.arch}').resnet(args.cfg, honey=honey).to(device)
        load_resnet_honey_model(model, args.random_rule)
    elif args.arch == 'googlenet':
        pass
    elif args.arch == 'densenet':
        pass

    if len(args.gpus) != 1:
        model = nn.DataParallel(model, device_ids=args.gpus)

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)


    model.train()

    for epoch in range(args.calfitness_epoch):
        for batch, batch_data in enumerate(trainLoader):
            inputs = batch_data[0]['data'].to(device)
            targets = batch_data[0]['label'].squeeze().long().to(device)

            train_loader_len = int(math.ceil(trainLoader._size / args.train_batch_size))

            adjust_learning_rate(optimizer, epoch, batch, train_loader_len, args)

            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_func(output, targets)
            loss.backward()
            optimizer.step()

        trainLoader.reset()


    fit_accuracy = utils.AverageMeter()
    model.eval()
    with torch.no_grad():
        for batch_idx, batch_data in enumerate(testLoader):
            inputs = batch_data[0]['data'].to(device)
            targets = batch_data[0]['label'].squeeze().long().to(device)
            outputs = model(inputs)
            predicted = utils.accuracy(outputs, targets, topk=(1, 5))
            fit_accuracy.update(predicted[1], inputs.size(0))
    testLoader.reset()


    if fit_accuracy.avg == 0:
        fit_accuracy.avg = 0.01

    if fit_accuracy.avg > best_honey.fitness:
        best_honey_state = copy.deepcopy(model.module.state_dict() if len(args.gpus) > 1 else model.state_dict())
        best_honey.code = copy.deepcopy(honey)
        best_honey.fitness = fit_accuracy.avg

    return fit_accuracy.avg
Example #30
def train(epoch, train_loader, model, criterion, optimizer):
    batch_time = utils.AverageMeter('Time', ':6.3f')
    data_time = utils.AverageMeter('Data', ':6.3f')
    losses = utils.AverageMeter('Loss', ':.4e')
    top1 = utils.AverageMeter('Acc@1', ':6.2f')
    top5 = utils.AverageMeter('Acc@5', ':6.2f')

    model.train()
    end = time.time()

    if args.use_dali:
        num_iter = train_loader._size // args.train_batch_size
    else:
        num_iter = len(train_loader)

    print_freq = num_iter // 10
    i = 0
    if args.use_dali:
        for batch_idx, batch_data in enumerate(train_loader):
            if args.debug:
                if i > 5:
                    break
                i += 1
            images = batch_data[0]['data'].cuda()
            targets = batch_data[0]['label'].squeeze().long().cuda()
            data_time.update(time.time() - end)

            adjust_learning_rate(optimizer, epoch, batch_idx, num_iter)

            # compute output
            logits = model(images)
            loss = loss_func(logits, targets)

            # measure accuracy and record loss
            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)  # accumulated loss
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % print_freq == 0 and batch_idx != 0:
                logger.info(
                    'Epoch[{0}]({1}/{2}): '
                    'Loss {loss.avg:.4f} '
                    'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                        epoch,
                        batch_idx,
                        num_iter,
                        loss=losses,
                        top1=top1,
                        top5=top5))
    else:
        for batch_idx, (images, targets) in enumerate(train_loader):
            if args.debug:
                if i > 5:
                    break
                i += 1
            images = images.cuda()
            targets = targets.cuda()
            data_time.update(time.time() - end)

            adjust_learning_rate(optimizer, epoch, batch_idx, num_iter)

            # compute output
            logits = model(images)
            loss = loss_func(logits, targets)

            # measure accuracy and record loss
            prec1, prec5 = utils.accuracy(logits, targets, topk=(1, 5))
            n = images.size(0)
            losses.update(loss.item(), n)  # accumulated loss
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_idx % print_freq == 0 and batch_idx != 0:
                logger.info(
                    'Epoch[{0}]({1}/{2}): '
                    'Loss {loss.avg:.4f} '
                    'Prec@1(1,5) {top1.avg:.2f}, {top5.avg:.2f}'.format(
                        epoch,
                        batch_idx,
                        num_iter,
                        loss=losses,
                        top1=top1,
                        top5=top5))

    return losses.avg, top1.avg, top5.avg