def trian_validate_with_scheduling(args, net, criterion, optimizer, compress_scheduler, device,
                                   epoch=1, validate=True, verbose=True, tracker=None):
    """Run one training epoch (optionally followed by validation) and checkpoint the result.

    The steps, in order:
      1. notify ``compress_scheduler`` (if given) that the epoch begins;
      2. train for one epoch via ``light_train_with_distiller``;
      3. if ``validate`` is true, replace the training metrics with those
         returned by ``_validate('val', ...)``;
      4. print the weights-sparsity table and the total sparsity of ``net``;
      5. notify ``compress_scheduler`` (if given) that the epoch ended;
      6. record the scores in a ``SparsityAccuracyTracker`` and save a checkpoint.

    Args:
        args: Run configuration. This function reads ``num_best_scores``,
            ``arch``, ``name`` and ``model_path`` from it.
        net: The model being trained.
        criterion: Loss function forwarded to the train/validate helpers.
        optimizer: Optimizer forwarded to training, the scheduler and the checkpoint.
        compress_scheduler: Compression scheduler, or a falsy value to skip
            the epoch-begin / epoch-end callbacks.
        device: Device forwarded to the train/validate helpers.
        epoch: Index of the current epoch.
        validate: When true, the returned metrics come from the validation pass.
        verbose: Currently unused; kept for interface compatibility.
        tracker: Optional existing ``SparsityAccuracyTracker``. Pass the same
            instance on every call to compare scores across epochs. When
            ``None`` a fresh tracker is created for this call only.

    Returns:
        Tuple ``(top1, top5, loss, tracker)``.
    """
    # TODO(review): what is `collectors_context`? (open question left by the original author)
    if compress_scheduler:
        compress_scheduler.on_epoch_begin(epoch)

    top1, top5, loss = light_train_with_distiller(
        net, criterion, optimizer, compress_scheduler, device, epoch)

    if validate:
        # Validation metrics supersede the training metrics for everything below.
        top1, top5, loss = _validate('val', net, criterion, device)

    t, total = summary.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
    print("\nParameters:\n" + str(t))
    print('Total sparsity: {:0.2f}\n'.format(total))

    if compress_scheduler:
        # 'min' / 'max' name the metric to minimise / maximise; presumably
        # consumed by metric-driven policies inside the scheduler -- confirm.
        compress_scheduler.on_epoch_end(
            epoch, optimizer, metrics={'min': loss, 'max': top1})

    # A tracker created here only ever sees this epoch's scores, so callers
    # that want a meaningful "best so far" must pass their own via `tracker`.
    if tracker is None:
        tracker = pt.SparsityAccuracyTracker(args.num_best_scores)
    tracker.step(net, epoch, top1=top1, top5=top5)
    best_score = tracker.best_scores()[0]
    is_best = epoch == best_score.epoch

    checkpoint_extras = {
        'current_top1': top1,
        'best_top1': best_score.top1,
        'best_epoch': best_score.epoch,
    }
    # Save under the *current* epoch index (previously `args.epoch`, the value
    # taken from the run configuration, was written into every checkpoint).
    # args.arch is the architecture name.
    ckpt.save_checkpoint(epoch, args.arch, net,
                         optimizer=optimizer,
                         scheduler=compress_scheduler,
                         extras=checkpoint_extras,
                         is_best=is_best,
                         name=args.name,
                         dir=args.model_path)

    return top1, top5, loss, tracker
# compress_scheduler.retrain_phase = True #else: # lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7) model.to(device) print("\nStart Training") # *************************************************** # Print the initial sparsity of this model, and please check whether the pruning # weight name is correct or not. # *************************************************** t, total = summary.weights_sparsity_tbl_summary( model, return_total_sparsity=True) print("\nParameters Table: {}".format(str(t))) print("\nSparsity: {}.".format(total)) tracker = pt.SparsityAccuracyTracker(args.num_best_scores) tracker.reset() if args.apex: model, optimizer = amp.initialize(model, optimizer, opt_level="O1") for epoch in range(args.epoch): print("\n") #print(compress_scheduler.policies[200]) nat_loss, tracker = train.trian_validate_with_scheduling( args, model, optimizer, compress_scheduler, device, dataloaders, dataset_sizes,