def run_training(args):
    # create model
    training_loss = 0
    training_acc = 0
    model = models.__dict__[args.arch](args.pretrained)
    model = torch.nn.DataParallel(model).cuda()

    if args.swa_start is not None:
        print('SWA training')
        swa_model = torch.nn.DataParallel(models.__dict__[args.arch](
            args.pretrained)).cuda()
        swa_n = 0
    else:
        print('SGD training')

    best_prec1 = 0
    best_iter = 0
    best_swa_prec = 0
    best_swa_iter = 0
    # best_full_prec = 0

    if args.resume:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])

            if args.swa_start is not None:
                swa_state_dict = checkpoint['swa_state_dict']
                if swa_state_dict is not None:
                    swa_model.load_state_dict(swa_state_dict)
                swa_n_ckpt = checkpoint['swa_n']
                if swa_n_ckpt is not None:
                    swa_n = swa_n_ckpt
                best_swa_prec_ckpt = checkpoint['best_swa_prec']
                if best_swa_prec_ckpt is not None:
                    best_swa_prec = best_swa_prec_ckpt

            logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
                args.resume, checkpoint['iter']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = False

    train_loader = prepare_train_data(dataset=args.dataset,
                                      datadir=args.datadir,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    if args.swa_start is not None:
        swa_loader = prepare_train_data(dataset=args.dataset,
                                        datadir=args.datadir,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.workers)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    # optimizer = torch.optim.Adam(model.parameters(), args.lr,
    #                              weight_decay=args.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    cr = AverageMeter()

    end = time.time()

    global scale_loss
    global turning_point_count
    global my_loss_diff_indicator

    i = args.start_iter
    while i < args.iters:
        for input, target in train_loader:
            # measuring data loading time
            data_time.update(time.time() - end)

            model.train()
            adjust_learning_rate(args, optimizer, i)
            # adjust_precision(args, i)
            adaptive_adjust_precision(args, turning_point_count)
            i += 1

            fw_cost = args.num_bits * args.num_bits / 32 / 32
            eb_cost = args.num_bits * args.num_grad_bits / 32 / 32
            gc_cost = eb_cost
            cr.update((fw_cost + eb_cost + gc_cost) / 3)

            target = target.squeeze().long().cuda()
            input_var = Variable(input).cuda()
            target_var = Variable(target).cuda()

            # compute output
            output = model(input_var, args.num_bits, args.num_grad_bits)
            loss = criterion(output, target_var)
            training_loss += loss.item()

            # measure accuracy and record loss
            prec1, = accuracy(output.data, target, topk=(1, ))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            training_acc += prec1.item()

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print log
            if i % args.print_freq == 0:
                logging.info(
                    "Iter: [{0}/{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t".format(
                        i,
                        args.iters,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top1=top1))

            if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
                util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
                swa_n += 1
                util_swa.bn_update(swa_loader, swa_model, args.num_bits,
                                   args.num_grad_bits)
                prec1 = validate(args, test_loader, swa_model, criterion, i,
                                 swa=True)
                if prec1 > best_swa_prec:
                    best_swa_prec = prec1
                    best_swa_iter = i
                # print("Current Best SWA Prec@1: ", best_swa_prec)
                # print("Current Best SWA Iteration: ", best_swa_iter)

            if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
                # record training loss and test accuracy
                global history_score
                epoch = i // args.eval_every
                epoch_loss = training_loss / len(train_loader)
                with torch.no_grad():
                    prec1 = validate(args, test_loader, model, criterion, i)
                    # prec_full = validate_full_prec(args, test_loader, model, criterion, i)
                history_score[epoch - 1][0] = epoch_loss
                history_score[epoch - 1][1] = np.round(
                    training_acc / len(train_loader), 2)
                history_score[epoch - 1][2] = prec1
                training_loss = 0
                training_acc = 0
                np.savetxt(os.path.join(save_path, 'record.txt'),
                           history_score,
                           fmt='%10.5f',
                           delimiter=',')

                # apply indicator
                # if epoch == 1:
                #     logging.info('initial loss value: {}'.format(epoch_loss))
                #     my_loss_diff_indicator.scale_loss = epoch_loss
                if epoch <= 10:
                    scale_loss += epoch_loss
                    logging.info('scale_loss at epoch {}: {}'.format(
                        epoch, scale_loss / epoch))
                    my_loss_diff_indicator.scale_loss = scale_loss / epoch
                if turning_point_count < args.num_turning_point:
                    my_loss_diff_indicator.get_loss(epoch_loss)
                    flag = my_loss_diff_indicator.turning_point_emerge()
                    if flag == True:
                        turning_point_count += 1
                        logging.info(
                            'find {}-th turning point at {}-th epoch'.format(
                                turning_point_count, epoch))
                        # print('find {}-th turning point at {}-th epoch'.format(turning_point_count, epoch))
                        my_loss_diff_indicator.adaptive_threshold(
                            turning_point_count=turning_point_count)
                        my_loss_diff_indicator.reset()

                logging.info(
                    'Epoch [{}] num_bits = {} num_grad_bits = {}'.format(
                        epoch, args.num_bits, args.num_grad_bits))

                # print statistics
                is_best = prec1 > best_prec1
                if is_best:
                    best_prec1 = prec1
                    best_iter = i
                    # best_full_prec = max(prec_full, best_full_prec)

                logging.info("Current Best Prec@1: {}".format(best_prec1))
                logging.info("Current Best Iteration: {}".format(best_iter))
                logging.info(
                    "Current Best SWA Prec@1: {}".format(best_swa_prec))
                logging.info(
                    "Current Best SWA Iteration: {}".format(best_swa_iter))
                # print("Current Best Full Prec@1: ", best_full_prec)

                # checkpoint_path = os.path.join(args.save_path, 'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
                checkpoint_path = os.path.join(args.save_path, 'ckpt.pth.tar')
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'swa_state_dict': swa_model.state_dict() if args.swa_start is not None else None,
                        'swa_n': swa_n if args.swa_start is not None else None,
                        'best_swa_prec': best_swa_prec if args.swa_start is not None else None,
                    },
                    is_best,
                    filename=checkpoint_path)
                # shutil.copyfile(checkpoint_path, os.path.join(args.save_path,
                #                                               'checkpoint_latest'
                #                                               '.pth.tar'))

            if i == args.iters:
                print("Best accuracy: " + str(best_prec1))
                history_score[-1][0] = best_prec1
                np.savetxt(os.path.join(save_path, 'record.txt'),
                           history_score,
                           fmt='%10.5f',
                           delimiter=',')
                break
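

# ---------------------------------------------------------------------------
# NOTE: `my_loss_diff_indicator`, `adaptive_adjust_precision`, and the other
# globals used above are defined elsewhere in this repo. The class below is
# only an illustrative sketch of how such a loss-difference "turning point"
# detector could look; the class name, window size, and threshold rule are
# assumptions made for clarity, not the repo's actual implementation.
# ---------------------------------------------------------------------------
class LossDiffIndicatorSketch:
    """Illustrative detector: flags a turning point when the smoothed
    epoch-to-epoch loss difference, normalized by `scale_loss`, falls
    below a threshold, i.e. when the loss curve has flattened."""

    def __init__(self, threshold=0.05, window=3):
        self.threshold = threshold      # relative loss-change threshold (assumed)
        self.window = window            # number of recent epoch losses to keep
        self.scale_loss = 1.0           # running loss scale, set by the caller
        self.losses = []

    def get_loss(self, epoch_loss):
        # Keep only the most recent `window` epoch losses.
        self.losses.append(epoch_loss)
        if len(self.losses) > self.window:
            self.losses.pop(0)

    def turning_point_emerge(self):
        # Not enough history yet: no decision.
        if len(self.losses) < self.window:
            return False
        diffs = [abs(self.losses[k + 1] - self.losses[k])
                 for k in range(len(self.losses) - 1)]
        mean_diff = sum(diffs) / len(diffs)
        return mean_diff / max(self.scale_loss, 1e-12) < self.threshold

    def adaptive_threshold(self, turning_point_count):
        # Tighten the threshold after each detected turning point (assumed
        # rule; `turning_point_count` mirrors the call site but is unused here).
        self.threshold /= 2.0

    def reset(self):
        self.losses = []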
def run_training(args):
    # create model
    model = models.__dict__[args.arch](args.pretrained)
    model = torch.nn.DataParallel(model).cuda()

    if args.swa_start is not None:
        print('SWA training')
        swa_model = torch.nn.DataParallel(models.__dict__[args.arch](
            args.pretrained)).cuda()
        swa_n = 0
    else:
        print('SGD training')

    best_prec1 = 0
    best_iter = 0
    best_swa_prec = 0
    best_swa_iter = 0

    if args.resume:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_iter = checkpoint['iter']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])

            if args.swa_start is not None:
                swa_state_dict = checkpoint['swa_state_dict']
                if swa_state_dict is not None:
                    swa_model.load_state_dict(swa_state_dict)
                swa_n_ckpt = checkpoint['swa_n']
                if swa_n_ckpt is not None:
                    swa_n = swa_n_ckpt
                best_swa_prec_ckpt = checkpoint['best_swa_prec']
                if best_swa_prec_ckpt is not None:
                    best_swa_prec = best_swa_prec_ckpt

            logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
                args.resume, checkpoint['iter']))
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = False

    train_loader = prepare_train_data(dataset=args.dataset,
                                      datadir=args.datadir,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    if args.swa_start is not None:
        swa_loader = prepare_train_data(dataset=args.dataset,
                                        datadir=args.datadir,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.workers)

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    cr = AverageMeter()

    end = time.time()

    i = args.start_iter
    while i < args.iters:
        for input, target in train_loader:
            # measuring data loading time
            data_time.update(time.time() - end)

            model.train()
            adjust_learning_rate(args, optimizer, i)
            cyclic_period = int(args.iters / args.num_cyclic_period)
            cyclic_adjust_precision(args, i, cyclic_period)
            i += 1

            fw_cost = args.num_bits * args.num_bits / 32 / 32
            eb_cost = args.num_bits * args.num_grad_bits / 32 / 32
            gc_cost = eb_cost
            cr.update((fw_cost + eb_cost + gc_cost) / 3)

            target = target.squeeze().long().cuda()
            input_var = Variable(input).cuda()
            target_var = Variable(target).cuda()

            # compute output
            output = model(input_var, args.num_bits, args.num_grad_bits)
            loss = criterion(output, target_var)

            # measure accuracy and record loss
            prec1, = accuracy(output.data, target, topk=(1, ))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print log
            if i % args.print_freq == 0:
                logging.info("Num bit {}\t"
                             "Num grad bit {}\t".format(args.num_bits,
                                                        args.num_grad_bits))
                logging.info(
                    "Iter: [{0}/{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Training FLOPS ratio: {cr.val:.6f} ({cr.avg:.6f})\t".format(
                        i,
                        args.iters,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top1=top1,
                        cr=cr))

            if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
                util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
                swa_n += 1
                util_swa.bn_update(swa_loader, swa_model, args.num_bits,
                                   args.num_grad_bits)
                prec1 = validate(args, test_loader, swa_model, criterion, i,
                                 swa=True)
                if prec1 > best_swa_prec:
                    best_swa_prec = prec1
                    best_swa_iter = i
                print("Current Best SWA Prec@1: ", best_swa_prec)
                print("Current Best SWA Iteration: ", best_swa_iter)

            if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
                with torch.no_grad():
                    prec1 = validate(args, test_loader, model, criterion, i)

                is_best = prec1 > best_prec1
                if is_best:
                    best_prec1 = prec1
                    best_iter = i

                print("Current Best Prec@1: ", best_prec1)
                print("Current Best Iteration: ", best_iter)
                print("Current cr val: {}, cr avg: {}".format(cr.val, cr.avg))

                checkpoint_path = os.path.join(
                    args.save_path,
                    'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'swa_state_dict': swa_model.state_dict() if args.swa_start is not None else None,
                        'swa_n': swa_n if args.swa_start is not None else None,
                        'best_swa_prec': best_swa_prec if args.swa_start is not None else None,
                    },
                    is_best,
                    filename=checkpoint_path)
                shutil.copyfile(
                    checkpoint_path,
                    os.path.join(args.save_path,
                                 'checkpoint_latest'
                                 '.pth.tar'))

            if i == args.iters:
                break
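

# ---------------------------------------------------------------------------
# `cyclic_adjust_precision` is defined elsewhere in this repo. The function
# below is a minimal sketch of one common way to cycle the forward and
# gradient bit-widths between a lower and an upper bound with a cosine
# schedule over each cycle; the attribute names (`num_bit_min`, `num_bit_max`,
# `num_grad_bits_min`, `num_grad_bits_max`) are assumptions and may not match
# the repo's actual flags.
# ---------------------------------------------------------------------------
import math


def cyclic_adjust_precision_sketch(args, _iter, cyclic_period):
    # Position inside the current cycle, in [0, 1).
    phase = (_iter % cyclic_period) / float(cyclic_period)
    # Cosine ramp from 0 to 1 over one cycle.
    ratio = 0.5 * (1 - math.cos(math.pi * phase))
    # Interpolate both bit-widths between their assumed min/max bounds.
    args.num_bits = int(round(
        args.num_bit_min + ratio * (args.num_bit_max - args.num_bit_min)))
    args.num_grad_bits = int(round(
        args.num_grad_bits_min
        + ratio * (args.num_grad_bits_max - args.num_grad_bits_min)))
    # Example: with num_bit_min=3, num_bit_max=8 and
    # cyclic_period = args.iters // args.num_cyclic_period, the forward
    # precision ramps from 3 to 8 bits once per cycle and then restarts.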
def run_training(args):
    global conv_info

    cost_fw = []
    for bit in bits:
        cost_fw.append(bit / 32)
    cost_fw = np.array(cost_fw) * args.weight_bits / 32

    cost_eb = []
    for bit in grad_bits:
        cost_eb.append(bit / 32)
    cost_eb = np.array(cost_eb) * args.weight_bits / 32

    cost_gc = []
    for i in range(len(bits)):
        cost_gc.append(bits[i] * grad_bits[i] / 32 / 32)
    cost_gc = np.array(cost_gc)

    # create model
    model = models.__dict__[args.arch](args.pretrained, proj_dim=len(bits))
    model = torch.nn.DataParallel(model).cuda()

    if args.swa_start is not None:
        print('SWA training')
        swa_model = torch.nn.DataParallel(
            models.__dict__[args.arch](args.pretrained,
                                       proj_dim=len(bits))).cuda()
        swa_n = 0
    else:
        print('SGD training')

    best_prec1 = 0
    best_iter = 0
    # best_full_prec = 0

    best_swa_prec = 0
    best_swa_iter = 0

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logging.info('=> loading checkpoint `{}`'.format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.proceed == 'True':
                args.start_iter = checkpoint['iter']
            else:
                args.start_iter = 0
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('=> loaded checkpoint `{}` (iter: {})'.format(
                args.resume, checkpoint['iter']))

            if args.swa_start is not None:
                swa_state_dict = checkpoint['swa_state_dict']
                if swa_state_dict is not None:
                    swa_model.load_state_dict(swa_state_dict)
                swa_n_ckpt = checkpoint['swa_n']
                if swa_n_ckpt is not None:
                    swa_n = swa_n_ckpt
                best_swa_prec_ckpt = checkpoint['best_swa_prec']
                if best_swa_prec_ckpt is not None:
                    best_swa_prec = best_swa_prec_ckpt
        else:
            logging.info('=> no checkpoint found at `{}`'.format(args.resume))

    cudnn.benchmark = True

    train_loader = prepare_train_data(dataset=args.dataset,
                                      datadir=args.datadir,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.workers)
    test_loader = prepare_test_data(dataset=args.dataset,
                                    datadir=args.datadir,
                                    batch_size=args.batch_size,
                                    shuffle=False,
                                    num_workers=args.workers)
    if args.swa_start is not None:
        swa_loader = prepare_train_data(dataset=args.dataset,
                                        datadir=args.datadir,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.workers)

    if args.rnn_initial:
        for param in model.parameters():
            param.requires_grad = False
        for param in model.control.parameters():
            param.requires_grad = True
        for param in model.control_grad.parameters():
            param.requires_grad = True
        for g in range(3):
            for i in range(model.num_layers[g]):
                gate_layer = getattr(model, 'group{}_gate{}'.format(g + 1, i))
                for param in gate_layer.parameters():
                    param.requires_grad = True

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    skip_ratios = ListAverageMeter()
    cp_record = AverageMeter()
    cp_record_fw = AverageMeter()
    cp_record_eb = AverageMeter()
    cp_record_gc = AverageMeter()

    network_depth = sum(model.module.num_layers)
    if conv_info is None:
        conv_info = [1 for _ in range(network_depth)]

    layerwise_decision_statistics = []
    for k in range(network_depth):
        layerwise_decision_statistics.append([])
        for j in range(len(cost_fw)):
            ratio = AverageMeter()
            layerwise_decision_statistics[k].append(ratio)

    end = time.time()

    i = args.start_iter
    while i < args.iters + args.finetune_step:
        for input, target in train_loader:
            # measuring data loading time
            data_time.update(time.time() - end)

            model.train()

            # adjust_learning_rate(args, optimizer1, optimizer2, i)
            adjust_learning_rate(args, optimizer, i)
            adjust_target_ratio(args, i)
            i += 1

            target = target.cuda()
            input_var = Variable(input).cuda()
            target_var = Variable(target).cuda()

            if i > args.iters:
                output, _ = model(input_var, np.zeros(len(bits)),
                                  np.zeros(len(grad_bits)))

                computation_cost = 0
                cp_ratio = 1
                cp_ratio_fw = 1
                cp_ratio_eb = 1
                cp_ratio_gc = 1
            else:
                output, masks = model(input_var, bits, grad_bits)

                computation_cost_fw = 0
                computation_cost_eb = 0
                computation_cost_gc = 0

                for layer in range(network_depth):
                    full_layer = reduce((lambda x, y: x * y),
                                        masks[layer][0].shape)
                    for k in range(len(cost_fw)):
                        dynamic_choice = masks[layer][k].sum()
                        ratio = dynamic_choice / full_layer
                        layerwise_decision_statistics[layer][k].update(ratio.data, 1)

                        computation_cost_fw += masks[layer][k].sum() * cost_fw[k] * conv_info[layer]
                        computation_cost_eb += masks[layer][k].sum() * cost_eb[k] * conv_info[layer]
                        computation_cost_gc += masks[layer][k].sum() * cost_gc[k] * conv_info[layer]

                computation_cost_fw += dws_flops_fw * args.batch_size
                computation_cost_eb += dws_flops_eb * args.batch_size
                computation_cost_gc += dws_flops_gc * args.batch_size

                computation_cost = computation_cost_fw + computation_cost_eb + computation_cost_gc

                cp_ratio_fw = float(computation_cost_fw) / args.batch_size / (
                    sum(conv_info) + dws_flops_fw) * 100
                cp_ratio_eb = float(computation_cost_eb) / args.batch_size / (
                    sum(conv_info) + dws_flops_eb) * 100
                cp_ratio_gc = float(computation_cost_gc) / args.batch_size / (
                    sum(conv_info) + dws_flops_gc) * 100
                cp_ratio = float(computation_cost) / args.batch_size / (
                    sum(conv_info) * 3 + dws_flops_total) * 100

                computation_loss = computation_cost / np.mean(conv_info) * args.beta

            if cp_ratio < args.target_ratio - args.relax:
                reg = -1
            elif cp_ratio >= args.target_ratio + args.relax:
                reg = 1
            elif cp_ratio >= args.target_ratio:
                reg = 0.1
            else:
                reg = -0.1

            loss_cls = criterion(output, target_var)

            if computation_loss > loss_cls / 10 and args.ada_beta:
                computation_loss *= loss_cls.detach() / 10 / computation_loss.detach()

            if args.computation_loss:
                loss = loss_cls + computation_loss * reg
            else:
                loss = loss_cls

            # measure accuracy and record loss
            prec1, = accuracy(output.data, target, topk=(1,))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            # skip_ratios.update(skips, input.size(0))
            cp_record.update(cp_ratio, 1)
            cp_record_fw.update(cp_ratio_fw, 1)
            cp_record_eb.update(cp_ratio_eb, 1)
            cp_record_gc.update(cp_ratio_gc, 1)

            optimizer.zero_grad()

            if args.loss_sf:
                loss *= args.loss_sf

            loss.backward()

            if args.loss_sf:
                for param in model.parameters():
                    if param.requires_grad and param.grad is not None:
                        param.grad.data /= args.loss_sf

            optimizer.step()

            # repackage hidden units for RNN Gate
            if args.gate_type == 'rnn':
                model.module.control.repackage_hidden()

            batch_time.update(time.time() - end)
            end = time.time()

            # print log
            if i % args.print_freq == 0 or i == (args.iters - 1):
                logging.info(
                    "Iter: [{0}/{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {loss.val:.3f} ({loss.avg:.3f})\t"
                    "Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t"
                    "Computation_Percentage: {cp_record.val:.3f}({cp_record.avg:.3f})\t"
                    "Computation_Percentage_FW: {cp_record_fw.val:.3f}({cp_record_fw.avg:.3f})\t"
                    "Computation_Percentage_EB: {cp_record_eb.val:.3f}({cp_record_eb.avg:.3f})\t"
                    "Computation_Percentage_GC: {cp_record_gc.val:.3f}({cp_record_gc.avg:.3f})\t".format(
                        i,
                        args.iters,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        top1=top1,
                        cp_record=cp_record,
                        cp_record_fw=cp_record_fw,
                        cp_record_eb=cp_record_eb,
                        cp_record_gc=cp_record_gc))

            if args.swa_start is not None and i >= args.swa_start and i % args.swa_freq == 0:
                util_swa.moving_average(swa_model, model, 1.0 / (swa_n + 1))
                swa_n += 1
                util_swa.bn_update(swa_loader, swa_model, bits, grad_bits)
                with torch.no_grad():
                    prec1 = validate(args, test_loader, swa_model, criterion,
                                     i, swa=True)
                if prec1 > best_swa_prec:
                    best_swa_prec = prec1
                    best_swa_iter = i
                print("Current Best SWA Prec@1: ", best_swa_prec)
                print("Current Best SWA Iteration: ", best_swa_iter)

            if (i % args.eval_every == 0 and i > 0) or (i == args.iters):
                with torch.no_grad():
                    prec1 = validate(args, test_loader, model, criterion, i)
                    # prec_full = validate_full_prec(args, test_loader, model, criterion, i)

                is_best = prec1 > best_prec1
                if is_best:
                    best_prec1 = prec1
                    best_iter = i
                    # best_full_prec = max(prec_full, best_full_prec)

                print("Current Best Prec@1: ", best_prec1)
                print("Current Best Iteration: ", best_iter)

                checkpoint_path = os.path.join(
                    args.save_path,
                    'checkpoint_{:05d}_{:.2f}.pth.tar'.format(i, prec1))
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': args.arch,
                        'state_dict': model.state_dict(),
                        'best_prec1': best_prec1,
                        'swa_state_dict': swa_model.state_dict() if args.swa_start is not None else None,
                        'swa_n': swa_n if args.swa_start is not None else None,
                        'best_swa_prec': best_swa_prec if args.swa_start is not None else None,
                    },
                    is_best=is_best,
                    filename=checkpoint_path)
                shutil.copyfile(
                    checkpoint_path,
                    os.path.join(args.save_path,
                                 'checkpoint_latest'
                                 '.pth.tar'))

            if i >= args.iters + args.finetune_step:
                break
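

# ---------------------------------------------------------------------------
# The sign and magnitude of the computation-cost penalty above are driven by
# where the measured compute percentage `cp_ratio` falls relative to the
# target band [target_ratio - relax, target_ratio + relax]. The helper below
# is a sketch of that rule pulled out as a standalone function for clarity;
# the thresholds mirror the in-line code, but the function itself is
# illustrative and not part of the repo.
# ---------------------------------------------------------------------------
def cost_reg_sign_sketch(cp_ratio, target_ratio, relax):
    """Return the regularization weight used to push cp_ratio toward the
    target band: penalize compute hard (+1) when far above the band, reward
    compute hard (-1) when far below, and nudge gently (+/-0.1) inside it."""
    if cp_ratio < target_ratio - relax:
        return -1        # far below target: reward using more compute
    elif cp_ratio >= target_ratio + relax:
        return 1         # far above target: penalize compute strongly
    elif cp_ratio >= target_ratio:
        return 0.1       # slightly above target: mild penalty
    else:
        return -0.1      # slightly below target: mild reward


# Example: with target_ratio=50 (percent) and relax=5,
#   cp_ratio=40 -> -1, cp_ratio=48 -> -0.1, cp_ratio=52 -> 0.1, cp_ratio=60 -> 1.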