Example 1
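All three examples are variants of the same quantized-training loop (a per-iteration precision schedule plus optional stochastic weight averaging), excerpted from a larger module. As a minimal sketch, assuming the project-local helpers (AverageMeter, ListAverageMeter, accuracy, validate, save_checkpoint, the prepare_train_data/prepare_test_data loaders, the adjust_*/precision-schedule functions and util_swa) and module-level globals (save_path, history_score, scale_loss, turning_point_count, my_loss_diff_indicator, bits, grad_bits, conv_info, dws_flops_*) are defined elsewhere in the repository, the snippets rely on roughly these imports:

import os
import time
import shutil
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

import models    # project-local model zoo, indexed via models.__dict__
import util_swa  # project-local stochastic weight averaging helpers
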
def run_training(args):
    # create model
    training_loss = 0
    training_acc = 0

    model = models.__dict__[args.arch](args.pretrained)
    model = torch.nn.DataParallel(model).cuda()

    if args.swa_start is not None:
        print('SWA training')
        swa_model = torch.nn.DataParallel(models.__dict__[args.arch](
            args.pretrained)).cuda()
        swa_n = 0

    else:
        print('SGD training')

    best_prec1 = 0
    best_iter = 0

    best_swa_prec = 0
    best_swa_iter = 0

    # best_full_prec = 0

    if args.resume:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])

            if args.swa_start is not None:
                swa_state_dict = checkpoint['swa_state_dict']
                if swa_state_dict is not None:
                    swa_model.load_state_dict(swa_state_dict)
                swa_n_ckpt = checkpoint['swa_n']
                if swa_n_ckpt is not None:
                    swa_n = swa_n_ckpt
                best_swa_prec_ckpt = checkpoint['best_swa_prec']
                if best_swa_prec_ckpt is not None:
                    best_swa_prec = best_swa_prec_ckpt

            logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
                args.resume, checkpoint['iter']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = False

    train_loader = prepare_train_data(dataset=args.dataset,
                                      datadir=args.datadir,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    if args.swa_start is not None:
        swa_loader = prepare_train_data(dataset=args.dataset,
                                        datadir=args.datadir,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.workers)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optimizer = torch.optim.Adam(model.parameters(), args.lr,
    #                             weight_decay=args.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    cr = AverageMeter()

    end = time.time()

    global scale_loss
    global turning_point_count
    global my_loss_diff_indicator

    i = args.start_iter
    while i < args.iters:
        for input, target in train_loader:
            # measuring data loading time
            data_time.update(time.time() - end)

            model.train()
            adjust_learning_rate(args, optimizer, i)
            # adjust_precision(args, i)
            adaptive_adjust_precision(args, turning_point_count)

            i += 1

            # relative cost of quantized multiply-accumulates vs. 32x32-bit
            # full precision, for forward (fw), error backprop (eb) and
            # gradient computation (gc); e.g. num_bits=8, num_grad_bits=16
            # gives (0.0625 + 0.125 + 0.125) / 3 ~= 0.104 of FP32 compute
            fw_cost = args.num_bits * args.num_bits / 32 / 32
            eb_cost = args.num_bits * args.num_grad_bits / 32 / 32
            gc_cost = eb_cost
            cr.update((fw_cost + eb_cost + gc_cost) / 3)

            target = target.squeeze().long().cuda()
            input_var = input.cuda()
            target_var = target

            # compute output
            output = model(input_var, args.num_bits, args.num_grad_bits)
            loss = criterion(output, target_var)
            training_loss += loss.item()

            # measure accuracy and record loss
            prec1, = accuracy(output.data, target, topk=(1, ))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            training_acc += prec1.item()

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print log
            if i % args.print_freq == 0:
                logging.info(
                    "Iter: [{0}/{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format(
                        i,
                        args.iters,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top1=top1))

            if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
                util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
                swa_n += 1
                util_swa.bn_update(swa_loader, swa_model, args.num_bits,
                                   args.num_grad_bits)
                prec1 = validate(args,
                                 test_loader,
                                 swa_model,
                                 criterion,
                                 i,
                                 swa=True)

                if prec1 > best_swa_prec:
                    best_swa_prec = prec1
                    best_swa_iter = i

                # print("Current Best SWA Prec@1: ", best_swa_prec)
                # print("Current Best SWA Iteration: ", best_swa_iter)

            if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
                # record training loss and test accuracy
                global history_score
                epoch = i // args.eval_every
                epoch_loss = training_loss / len(train_loader)
                with torch.no_grad():
                    prec1 = validate(args, test_loader, model, criterion, i)
                    # prec_full = validate_full_prec(args, test_loader, model, criterion, i)
                history_score[epoch - 1][0] = epoch_loss
                history_score[epoch - 1][1] = np.round(
                    training_acc / len(train_loader), 2)
                history_score[epoch - 1][2] = prec1
                training_loss = 0
                training_acc = 0

                np.savetxt(os.path.join(save_path, 'record.txt'),
                           history_score,
                           fmt='%10.5f',
                           delimiter=',')

                # apply indicator
                # if epoch == 1:
                #     logging.info('initial loss value: {}'.format(epoch_loss))
                #     my_loss_diff_indicator.scale_loss = epoch_loss
                if epoch <= 10:
                    scale_loss += epoch_loss
                    logging.info('scale_loss at epoch {}: {}'.format(
                        epoch, scale_loss / epoch))
                    my_loss_diff_indicator.scale_loss = scale_loss / epoch
                if turning_point_count < args.num_turning_point:
                    my_loss_diff_indicator.get_loss(epoch_loss)
                    flag = my_loss_diff_indicator.turning_point_emerge()
                    if flag:
                        turning_point_count += 1
                        logging.info(
                            'found {}-th turning point at {}-th epoch'.format(
                                turning_point_count, epoch))
                        my_loss_diff_indicator.adaptive_threshold(
                            turning_point_count=turning_point_count)
                        my_loss_diff_indicator.reset()

                logging.info(
                    'Epoch [{}] num_bits = {} num_grad_bits = {}'.format(
                        epoch, args.num_bits, args.num_grad_bits))

                # print statistics
                is_best = prec1 > best_prec1
                if is_best:
                    best_prec1 = prec1
                    best_iter = i
                # best_full_prec = max(prec_full, best_full_prec)

                logging.info("Current Best Prec@1: {}".format(best_prec1))
                logging.info("Current Best Iteration: {}".format(best_iter))
                logging.info(
                    "Current Best SWA Prec@1: {}".format(best_swa_prec))
                logging.info(
                    "Current Best SWA Iteration: {}".format(best_swa_iter))
                # print("Current Best Full Prec@1: ", best_full_prec)

                # checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
                checkpoint_path = os.path.join(args.save_path, 'ckpt.pth.tar')
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'swa_state_dict': (swa_model.state_dict()
                                           if args.swa_start is not None else None),
                        'swa_n': swa_n if args.swa_start is not None else None,
                        'best_swa_prec': (best_swa_prec
                                          if args.swa_start is not None else None),
                    },
                    is_best,
                    filename=checkpoint_path)
                # shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
                # 'checkpoint_latest'
                # '.pth.tar'))

                if i == args.iters:
                    print("Best accuracy: " + str(best_prec1))
                    history_score[-1][0] = best_prec1
                    np.savetxt(os.path.join(save_path, 'record.txt'),
                               history_score,
                               fmt='%10.5f',
                               delimiter=',')
                    break
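
Example 1's distinguishing mechanism is the loss-turning-point indicator that drives adaptive_adjust_precision; neither the indicator class nor that function is shown above. Below is a minimal hypothetical sketch consistent with the call sites (get_loss records an epoch loss, turning_point_emerge tests whether the normalized loss change has flattened, adaptive_threshold tightens the criterion after each detection, reset clears the history); the default threshold and window values are assumptions, not the project's.

# Hypothetical sketch of the loss-difference indicator assumed by Example 1.
class LossDiffIndicator:
    def __init__(self, threshold=0.05, window=3):
        self.threshold = threshold  # plateau criterion (assumed default)
        self.window = window        # epochs per comparison window
        self.scale_loss = 1.0       # set externally from early epoch losses
        self.history = []

    def get_loss(self, epoch_loss):
        self.history.append(epoch_loss)

    def turning_point_emerge(self):
        # a turning point: the mean loss of the last `window` epochs stopped
        # improving relative to the previous window, normalized by scale_loss
        if len(self.history) < 2 * self.window:
            return False
        prev = sum(self.history[-2 * self.window:-self.window]) / self.window
        curr = sum(self.history[-self.window:]) / self.window
        return abs(prev - curr) / self.scale_loss < self.threshold

    def adaptive_threshold(self, turning_point_count):
        # tighten the criterion after each detected turning point
        self.threshold /= 2 ** turning_point_count

    def reset(self):
        self.history = []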
Example 2
def run_training(args):
    # create model
    model = models.__dict__[args.arch](args.pretrained)
    model = torch.nn.DataParallel(model).cuda()

    if args.swa_start is not None:
        print('SWA training')
        swa_model = torch.nn.DataParallel(models.__dict__[args.arch](
            args.pretrained)).cuda()
        swa_n = 0

    else:
        print('SGD training')

    best_prec1 = 0
    best_iter = 0

    best_swa_prec = 0
    best_swa_iter = 0

    if args.resume:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])

            if args.swa_start is not None:
                swa_state_dict = checkpoint['swa_state_dict']
                if swa_state_dict is not None:
                    swa_model.load_state_dict(swa_state_dict)
                swa_n_ckpt = checkpoint['swa_n']
                if swa_n_ckpt is not None:
                    swa_n = swa_n_ckpt
                best_swa_prec_ckpt = checkpoint['best_swa_prec']
                if best_swa_prec_ckpt is not None:
                    best_swa_prec = best_swa_prec_ckpt

            logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
                args.resume, checkpoint['iter']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = False

    train_loader = prepare_train_data(dataset=args.dataset,
                                      datadir=args.datadir,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    if args.swa_start is not None:
        swa_loader = prepare_train_data(dataset=args.dataset,
                                        datadir=args.datadir,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.workers)

    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    cr = AverageMeter()

    end = time.time()

    i = args.start_iter
    while i < args.iters:
        for input, target in train_loader:
            # measuring data loading time
            data_time.update(time.time() - end)

            model.train()
            adjust_learning_rate(args, optimizer, i)

            # length of one precision cycle, in iterations
            cyclic_period = int(args.iters / args.num_cyclic_period)
            cyclic_adjust_precision(args, i, cyclic_period)

            i += 1

            fw_cost = args.num_bits * args.num_bits / 32 / 32
            eb_cost = args.num_bits * args.num_grad_bits / 32 / 32
            gc_cost = eb_cost
            cr.update((fw_cost + eb_cost + gc_cost) / 3)

            target = target.squeeze().long().cuda()
            input_var = input.cuda()
            target_var = target

            # compute output
            output = model(input_var, args.num_bits, args.num_grad_bits)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, = accuracy(output.data, target, topk=(1, ))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print log
            if i % args.print_freq == 0:
                logging.info("Num bit {}\t"
                             "Num grad bit {}\t".format(
                                 args.num_bits, args.num_grad_bits))
                logging.info(
                    "Iter: [{0}/{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Training FLOPS ratio: {cr.val:.6f} ({cr.avg:.6f})\t".
                    format(i,
                           args.iters,
                           batch_time=batch_time,
                           data_time=data_time,
                           loss=losses,
                           top1=top1,
                           cr=cr))

            if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
                util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
                swa_n += 1
                util_swa.bn_update(swa_loader, swa_model, args.num_bits,
                                   args.num_grad_bits)
                prec1 = validate(args,
                                 test_loader,
                                 swa_model,
                                 criterion,
                                 i,
                                 swa=True)

                if prec1 > best_swa_prec:
                    best_swa_prec = prec1
                    best_swa_iter = i

                print("Current Best SWA Prec@1: ", best_swa_prec)
                print("Current Best SWA Iteration: ", best_swa_iter)

            if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
                with torch.no_grad():
                    prec1 = validate(args, test_loader, model, criterion, i)

                is_best = prec1 > best_prec1
                if is_best:
                    best_prec1 = prec1
                    best_iter = i

                print("Current Best Prec@1: ", best_prec1)
                print("Current Best Iteration: ", best_iter)
                print("Current cr val: {}, cr avg: {}".format(cr.val, cr.avg))

                checkpoint_path = os.path.join(
                    args.save_path,
                    'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'swa_state_dict': (swa_model.state_dict()
                                           if args.swa_start is not None else None),
                        'swa_n': swa_n if args.swa_start is not None else None,
                        'best_swa_prec': (best_swa_prec
                                          if args.swa_start is not None else None),
                    },
                    is_best,
                    filename=checkpoint_path)
                shutil.copyfile(
                    checkpoint_path,
                    os.path.join(args.save_path, 'checkpoint_latest.pth.tar'))

                if i == args.iters:
                    break
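
Example 2 replaces Example 1's adaptive schedule with cyclic_adjust_precision, which is also not shown. A minimal sketch, assuming a cyclic-precision-training style cosine schedule that sweeps args.num_bits between two bounds within each cyclic_period; the num_bit_min/num_bit_max fields and tying num_grad_bits to num_bits are assumptions for illustration.

import numpy as np

# Hypothetical sketch of cyclic_adjust_precision; bound fields are assumed.
def cyclic_adjust_precision(args, iteration, cyclic_period):
    phase = (iteration % cyclic_period) / cyclic_period  # position in [0, 1)
    low, high = args.num_bit_min, args.num_bit_max       # assumed arg fields
    # cosine ramp from `low` at the start of each cycle to `high` at its end
    args.num_bits = int(round(low + (high - low) * (1 - np.cos(np.pi * phase)) / 2))
    args.num_grad_bits = args.num_bits                   # sketch: tied precision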
Example 3
def run_training(args):
    global conv_info

    # per-option relative compute costs vs. 32-bit full precision for the
    # forward pass (fw), error backprop (eb) and gradient computation (gc)
    cost_fw = np.array([bit / 32 for bit in bits]) * args.weight_bits / 32
    cost_eb = np.array([bit / 32 for bit in grad_bits]) * args.weight_bits / 32
    cost_gc = np.array([bits[k] * grad_bits[k] / 32 / 32
                        for k in range(len(bits))])

    # create model
    model = models.__dict__[args.arch](args.pretrained, proj_dim=len(bits))
    model = torch.nn.DataParallel(model).cuda()

    if args.swa_start is not None:
        print('SWA training')
        swa_model = torch.nn.DataParallel(models.__dict__[args.arch](args.pretrained, proj_dim=len(bits))).cuda()
        swa_n = 0

    else:
        print('SGD training')

    best_prec1 = 0
    best_iter = 0
    # best_full_prec = 0

    best_swa_prec = 0
    best_swa_iter = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.proceed == 'True':
                args.start_iter = checkpoint['iter']
            else:
                args.start_iter = 0
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
                args.resume, checkpoint['iter']
            ))

            if args.swa_start is not None:
                swa_state_dict = checkpoint['swa_state_dict']
                if swa_state_dict is not None:
                    swa_model.load_state_dict(swa_state_dict)    
                swa_n_ckpt = checkpoint['swa_n']
                if swa_n_ckpt is not None:
                    swa_n = swa_n_ckpt
                best_swa_prec_ckpt = checkpoint['best_swa_prec']
                if best_swa_prec_ckpt is not None:
                    best_swa_prec = best_swa_prec_ckpt

        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = True

    train_loader = prepare_train_data(dataset=args.dataset,
                                      datadir=args.datadir,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)

    if args.swa_start is not None:
        swa_loader = prepare_train_data(dataset=args.dataset,
                                        datadir=args.datadir,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.workers)

    if args.rnn_initial:
        # warm-up: freeze the backbone and train only the gating networks;
        # the model is wrapped in DataParallel, so submodules live under
        # model.module (cf. model.module.control.repackage_hidden() below)
        for param in model.parameters():
            param.requires_grad = False

        for param in model.module.control.parameters():
            param.requires_grad = True

        for param in model.module.control_grad.parameters():
            param.requires_grad = True

        for g in range(3):
            for layer in range(model.module.num_layers[g]):
                gate_layer = getattr(model.module,
                                     'group{}_gate{}'.format(g + 1, layer))
                for param in gate_layer.parameters():
                    param.requires_grad = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    skip_ratios = ListAverageMeter()
    cp_record = AverageMeter()
    cp_record_fw = AverageMeter()
    cp_record_eb = AverageMeter()
    cp_record_gc = AverageMeter()
    
    network_depth = sum(model.module.num_layers)

    if conv_info is None:
        conv_info = [1 for _ in range(network_depth)]

    layerwise_decision_statistics = []
    
    for k in range(network_depth):
        layerwise_decision_statistics.append([])
        for j in range(len(cost_fw)):
            ratio = AverageMeter()
            layerwise_decision_statistics[k].append(ratio)

    end = time.time()

    i = args.start_iter
    while i < args.iters + args.finetune_step:
        for input, target in train_loader:
            # measuring data loading time
            data_time.update(time.time() - end)

            model.train()
            # adjust_learning_rate(args, optimizer1, optimizer2, i)
            adjust_learning_rate(args, optimizer, i)
            adjust_target_ratio(args, i)
            i += 1

            target = target.cuda()
            input_var = input.cuda()
            target_var = target
           
            if i > args.iters:
                # finetuning phase: run with zeroed bit choices (full
                # precision) and disable the computation-cost penalty
                output, _ = model(input_var, np.zeros(len(bits)),
                                  np.zeros(len(grad_bits)))
                computation_cost = 0
                computation_loss = 0
                cp_ratio = 1
                cp_ratio_fw = 1
                cp_ratio_eb = 1
                cp_ratio_gc = 1

            else:
                output, masks = model(input_var, bits, grad_bits)
                
                computation_cost_fw = 0
                computation_cost_eb = 0
                computation_cost_gc = 0
                
                for layer in range(network_depth):
                    # total number of gating decisions in this layer's mask
                    full_layer = masks[layer][0].numel()

                    for k in range(len(cost_fw)):
                        # fraction of positions routed to precision option k
                        dynamic_choice = masks[layer][k].sum()
                        ratio = dynamic_choice / full_layer
                        layerwise_decision_statistics[layer][k].update(ratio.data, 1)

                        # accumulate compute cost, weighted by this layer's FLOPs
                        computation_cost_fw += dynamic_choice * cost_fw[k] * conv_info[layer]
                        computation_cost_eb += dynamic_choice * cost_eb[k] * conv_info[layer]
                        computation_cost_gc += dynamic_choice * cost_gc[k] * conv_info[layer]
                
                computation_cost_fw += dws_flops_fw * args.batch_size
                computation_cost_eb += dws_flops_eb * args.batch_size
                computation_cost_gc += dws_flops_gc * args.batch_size

                computation_cost = computation_cost_fw + computation_cost_eb + computation_cost_gc

                cp_ratio_fw = float(computation_cost_fw) / args.batch_size / (sum(conv_info) + dws_flops_fw) * 100
                cp_ratio_eb = float(computation_cost_eb) / args.batch_size / (sum(conv_info) + dws_flops_eb) * 100
                cp_ratio_gc = float(computation_cost_gc) / args.batch_size / (sum(conv_info) + dws_flops_gc) * 100

                cp_ratio = float(computation_cost) / args.batch_size / (sum(conv_info)*3 + dws_flops_total) * 100
                    
                computation_loss = computation_cost / np.mean(conv_info) * args.beta
            
            # sign/strength of the cost penalty: push cp_ratio toward
            # args.target_ratio, with a relaxation band of +/- args.relax
            if cp_ratio < args.target_ratio - args.relax:
                reg = -1
            elif cp_ratio >= args.target_ratio + args.relax:
                reg = 1
            elif cp_ratio >= args.target_ratio:
                reg = 0.1
            else:
                reg = -0.1
            
            loss_cls = criterion(output, target_var)

            # optionally rescale so the cost penalty stays below loss_cls / 10
            if args.ada_beta and computation_loss > loss_cls / 10:
                computation_loss *= loss_cls.detach() / 10 / computation_loss.detach()

            if args.computation_loss:
                loss = loss_cls + computation_loss * reg
            else:
                loss = loss_cls

            # measure accuracy and record loss
            prec1, = accuracy(output.data, target, topk=(1,))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            # skip_ratios.update(skips, input.size(0))
            cp_record.update(cp_ratio, 1)
            cp_record_fw.update(cp_ratio_fw, 1)
            cp_record_eb.update(cp_ratio_eb, 1)
            cp_record_gc.update(cp_ratio_gc, 1)

            optimizer.zero_grad()

            if args.loss_sf:
                loss *= args.loss_sf

            loss.backward()

            if args.loss_sf:
                for param in model.parameters():
                    if param.requires_grad and param.grad is not None:
                        param.grad.data /= args.loss_sf

            optimizer.step()

            # repackage hidden units for RNN Gate
            if args.gate_type == 'rnn':
                model.module.control.repackage_hidden()

            batch_time.update(time.time() - end)
            end = time.time()

            # print log
            if i % args.print_freq == 0 or i == (args.iters - 1):
                logging.info("Iter: [{0}/{1}]\t"
                             "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                             "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                             "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                             "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                             "Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t"
                             "Computation_Percentage_FW: {cp_record_fw.val:.3f}({cp_record_fw.avg:.3f})\t"
                             "Computation_Percentage_EB: {cp_record_eb.val:.3f}({cp_record_eb.avg:.3f})\t"
                             "Computation_Percentage_GC: {cp_record_gc.val:.3f}({cp_record_gc.avg:.3f})\t".format(
                                i,
                                args.iters,
                                batch_time=batch_time,
                                data_time=data_time,
                                loss=losses,
                                top1=top1,
                                cp_record=cp_record,
                                cp_record_fw=cp_record_fw,
                                cp_record_eb=cp_record_eb,
                                cp_record_gc=cp_record_gc)
                )
            
            if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
                util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
                swa_n += 1
                util_swa.bn_update(swa_loader, swa_model, bits, grad_bits)

                with torch.no_grad():
                    prec1 = validate(args, test_loader, swa_model, criterion, i, swa=True)

                if prec1 > best_swa_prec:
                    best_swa_prec = prec1
                    best_swa_iter = i

                print("Current Best SWA Prec@1: ", best_swa_prec)
                print("Current Best SWA Iteration: ", best_swa_iter)

            if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
                
                with torch.no_grad():
                    prec1 = validate(args, test_loader, model, criterion, i)
                    # prec_full = validate_full_prec(args, test_loader, model, criterion, i)

                is_best = prec1 > best_prec1
                if is_best:
                    best_prec1 = prec1
                    best_iter = i

                # best_full_prec = max(prec_full, best_full_prec)

                print("Current Best Prec@1: ", best_prec1)
                print("Current Best Iteration: ", best_iter)

                checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'swa_state_dict': (swa_model.state_dict()
                                           if args.swa_start is not None else None),
                        'swa_n': swa_n if args.swa_start is not None else None,
                        'best_swa_prec': (best_swa_prec
                                          if args.swa_start is not None else None),
                    },
                    is_best=is_best,
                    filename=checkpoint_path)
                shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
                                                              'checkpoint_latest'
                                                              '.pth.tar'))

            if i >= args.iters + args.finetune_step:
                break
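
All three examples delegate weight averaging to util_swa.moving_average and util_swa.bn_update, neither of which is shown. A minimal sketch of the standard SWA procedure they presumably implement: with alpha = 1.0 / (swa_n + 1), the first call copies the model and subsequent calls maintain a cumulative mean of the weights, after which BatchNorm running statistics are recomputed under the averaged weights (num_bits/num_grad_bits are threaded through only because the quantized forward pass expects them).

import torch

def moving_average(swa_model, model, alpha):
    # fold the current weights into the SWA running average
    for swa_param, param in zip(swa_model.parameters(), model.parameters()):
        swa_param.data.mul_(1.0 - alpha).add_(param.data, alpha=alpha)

def bn_update(loader, model, num_bits, num_grad_bits):
    # recompute BatchNorm running statistics for the averaged weights
    model.train()
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.reset_running_stats()
            module.momentum = None  # cumulative average over this pass
    with torch.no_grad():
        for input, _ in loader:
            model(input.cuda(), num_bits, num_grad_bits)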