コード例 #1
0
def trian_validate_with_scheduling(args,
                                   net,
                                   criterion,
                                   optimizer,
                                   compress_scheduler,
                                   device,
                                   epoch=1,
                                   validate=True,
                                   verbose=True,
                                   tracker=None):
    """Train for one epoch, optionally validate, then checkpoint the model.

    The compression scheduler (when given) is notified at the beginning and
    at the end of the epoch. After training, the weight-sparsity table of
    ``net`` is printed, the epoch's scores are recorded in a performance
    tracker and a checkpoint is written.

    Args:
        args: Run configuration. This function reads ``num_best_scores``,
            ``arch``, ``name`` and ``model_path`` from it.
        net: Model to train.
        criterion: Loss function forwarded to the train/validate helpers.
        optimizer: Optimizer forwarded to training and stored in the
            checkpoint.
        compress_scheduler: Compression scheduler, or a falsy value to
            disable the scheduler callbacks.
        device: Device forwarded to the train/validate helpers.
        epoch: Index of the epoch being run.
        validate: When true, the returned/recorded ``top1``, ``top5`` and
            ``loss`` come from the ``'val'`` pass instead of training.
        verbose: Currently unused; kept for interface compatibility.
        tracker: Optional performance tracker carried over from previous
            epochs. When ``None`` a fresh tracker is created, which means
            only the current epoch is known to it.

    Returns:
        Tuple ``(top1, top5, loss, tracker)``.
    """
    if compress_scheduler:
        compress_scheduler.on_epoch_begin(epoch)

    top1, top5, loss = light_train_with_distiller(net, criterion, optimizer,
                                                  compress_scheduler, device,
                                                  epoch)

    if validate:
        # Validation metrics replace the training metrics for everything
        # that follows (scheduler feedback, tracking, checkpoint extras).
        top1, top5, loss = _validate('val', net, criterion, device)

    t, total = summary.weights_sparsity_tbl_summary(net,
                                                    return_total_sparsity=True)
    print("\nParameters:\n" + str(t))
    print('Total sparsity: {:0.2f}\n'.format(total))

    if compress_scheduler:
        # NOTE(review): the 'min'/'max' keys are presumably consumed by
        # metric-driven policies of the scheduler -- confirm against the
        # scheduler implementation.
        compress_scheduler.on_epoch_end(epoch,
                                        optimizer,
                                        metrics={
                                            'min': loss,
                                            'max': top1
                                        })

    # Reuse the caller's tracker when one is supplied so the "best" epoch is
    # chosen across the whole run; a tracker created here only ever sees the
    # current epoch.
    if tracker is None:
        tracker = pt.SparsityAccuracyTracker(args.num_best_scores)
    tracker.step(net, epoch, top1=top1, top5=top5)
    best_score = tracker.best_scores()[0]
    is_best = epoch == best_score.epoch
    checkpoint_extras = {
        'current_top1': top1,
        'best_top1': best_score.top1,
        'best_epoch': best_score.epoch
    }

    # Save under the epoch that was actually run (not args.epoch), so the
    # checkpoint records the current position in the schedule.
    # args.arch is the architecture name.
    ckpt.save_checkpoint(epoch,
                         args.arch,
                         net,
                         optimizer=optimizer,
                         scheduler=compress_scheduler,
                         extras=checkpoint_extras,
                         is_best=is_best,
                         name=args.name,
                         dir=args.model_path)
    return top1, top5, loss, tracker
コード例 #2
0
ファイル: main_yolo.py プロジェクト: bwtseng/Object-Detection
                #    compress_scheduler.retrain_phase = True
            #else:
            #    lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.7)
            model.to(device)
            print("\nStart Training")

        # ***************************************************
        # Print the initial sparsity of this model, and please check whether the pruning
        # weight name is correct or not.
        # ***************************************************
        t, total = summary.weights_sparsity_tbl_summary(
            model, return_total_sparsity=True)
        print("\nParameters Table: {}".format(str(t)))
        print("\nSparsity: {}.".format(total))

        tracker = pt.SparsityAccuracyTracker(args.num_best_scores)
        tracker.reset()
        if args.apex:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

        for epoch in range(args.epoch):
            print("\n")
            #print(compress_scheduler.policies[200])
            nat_loss, tracker = train.trian_validate_with_scheduling(
                args,
                model,
                optimizer,
                compress_scheduler,
                device,
                dataloaders,
                dataset_sizes,