Example 1
def main_worker(args):
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)
    wandb.watch(model)

    if args.pretrained:
        pretrained(args, model)

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # Optionally evaluate the model and return
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
    args.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / f"initial.state",
        save=False,
    )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        modifier(args, epoch, model)

        cur_lr = get_lr(optimizer)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer,
                              epoch)
        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        # check save_every > 0 before the modulo to avoid division by zero
        save = args.save_every > 0 and ((epoch % args.save_every) == 0)
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )
            wandb.log({
                "curr_acc1": acc1,
                "curr_acc5": acc5,
            })

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        if args.conv_type == "SampleSubnetConv":
            count = 0
            sum_pr = 0.0
            for n, m in model.named_modules():
                if isinstance(m, SampleSubnetConv):
                    # avg pr across 10 samples
                    pr = 0.0
                    for _ in range(10):
                        pr += ((torch.rand_like(m.clamped_scores) >=
                                m.clamped_scores).float().mean().item())
                    pr /= 10.0
                    writer.add_scalar("pr/{}".format(n), pr, epoch)
                    sum_pr += pr
                    count += 1

            args.prune_rate = sum_pr / count
            writer.add_scalar("pr/average", args.prune_rate, epoch)

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
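
Every example in this listing swaps in LabelSmoothing(smoothing=args.label_smoothing) for plain cross-entropy when label smoothing is requested, but the helper itself never appears. Below is a minimal sketch of such a loss, assuming only the constructor signature used above; the repo's actual implementation may differ.

import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothing(nn.Module):
    """Cross-entropy with uniform label smoothing (sketch)."""

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        # log-probabilities over the class dimension
        logp = F.log_softmax(logits, dim=-1)
        # negative log-likelihood of the true class
        nll = -logp.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        # uniform smoothing term: mean of -log p over all classes
        smooth = -logp.mean(dim=-1)
        loss = (1.0 - self.smoothing) * nll + self.smoothing * smooth
        return loss.mean()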
Example 2
File: main.py Project: zj15001/STR
def main_worker(args):
    args.gpu = None
    # restore the trainer lookup used by the other examples;
    # train and validate are called below but were never bound here
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)

    # Loading pretrained model
    if args.pretrained:
        pretrained(args, model)

        # Saving a DenseConv (nn.Conv2d) compatible model
        if args.dense_conv_model:
            print(
                f"==> DenseConv compatible model, saving at {ckpt_base_dir / 'model_best.pth'}"
            )
            save_checkpoint(
                {
                    "epoch": 0,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                },
                True,
                filename=ckpt_base_dir / f"epoch_pretrained.state",
                save=True,
            )
            return

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # Evaluation of the model
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)
        return

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / f"initial.state",
        save=False,
    )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        cur_lr = get_lr(optimizer)

        # Gradual pruning in GMP experiments
        if args.conv_type == "GMPConv" and epoch >= args.init_prune_epoch and epoch <= args.final_prune_epoch:
            total_prune_epochs = args.final_prune_epoch - args.init_prune_epoch + 1
            for n, m in model.named_modules():
                if hasattr(m, 'set_curr_prune_rate'):
                    prune_decay = (
                        1 - ((args.curr_prune_epoch - args.init_prune_epoch) /
                             total_prune_epochs))**3
                    curr_prune_rate = m.prune_rate - (m.prune_rate *
                                                      prune_decay)
                    m.set_curr_prune_rate(curr_prune_rate)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer,
                              epoch)
        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        # check save_every > 0 before the modulo to avoid division by zero
        save = args.save_every > 0 and ((epoch % args.save_every) == 0)
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

        # Storing sparsity and threshold statistics for STRConv models
        if args.conv_type == "STRConv":
            count = 0
            sum_sparse = 0.0
            for n, m in model.named_modules():
                if isinstance(m, STRConv):
                    sparsity, total_params, thresh = m.getSparsity()
                    writer.add_scalar("sparsity/{}".format(n), sparsity, epoch)
                    writer.add_scalar("thresh/{}".format(n), thresh, epoch)
                    sum_sparse += int(((100 - sparsity) / 100) * total_params)
                    count += total_params
            total_sparsity = 100 - (100 * sum_sparse / count)
            writer.add_scalar("sparsity/total", total_sparsity, epoch)
        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
    if args.conv_type == "STRConv":
        json_data = {}
        json_thres = {}
        for n, m in model.named_modules():
            if isinstance(m, STRConv):
                sparsity = m.getSparsity()
                json_data[n] = sparsity[0]
                sum_sparse += int(((100 - sparsity[0]) / 100) * sparsity[1])
                count += sparsity[1]
                json_thres[n] = sparsity[2]
        json_data["total"] = 100 - (100 * sum_sparse / count)
        if not os.path.exists("runs/layerwise_sparsity"):
            os.mkdir("runs/layerwise_sparsity")
        if not os.path.exists("runs/layerwise_threshold"):
            os.mkdir("runs/layerwise_threshold")
        with open("runs/layerwise_sparsity/{}.json".format(args.name),
                  "w") as f:
            json.dump(json_data, f)
        with open("runs/layerwise_threshold/{}.json".format(args.name),
                  "w") as f:
            json.dump(json_thres, f)
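
The GMP branch in this example ramps each layer's prune rate along a cubic schedule: writing t for the fraction of the pruning window elapsed, the current rate is prune_rate * (1 - (1 - t)^3), which starts at 0 and reaches the full rate at the end of the window (the shape of Zhu and Gupta's gradual pruning schedule). A standalone sketch of that arithmetic, with hypothetical argument names since only the in-loop form appears above:

def gmp_curr_prune_rate(final_rate, epoch, init_epoch, final_epoch):
    """Cubic gradual-pruning schedule (sketch of Example 2's in-loop math)."""
    total_prune_epochs = final_epoch - init_epoch + 1
    t = (epoch - init_epoch) / total_prune_epochs  # fraction of the window elapsed
    prune_decay = (1 - t) ** 3                     # cubic decay toward 0
    return final_rate - final_rate * prune_decay   # == final_rate * (1 - (1 - t) ** 3)


# For example, with init_epoch=0, final_epoch=9, final_rate=0.9:
# epoch 0 -> 0.0, epoch 5 -> 0.7875, epoch 9 -> 0.8991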
Example 3
def main_worker(args):
    # NEW: equivalent to MPI init.
    print("world size ", os.environ['OMPI_COMM_WORLD_SIZE'])
    print("rank ", os.environ['OMPI_COMM_WORLD_RANK'])
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=int(os.environ['OMPI_COMM_WORLD_SIZE']),
        rank=int(os.environ['OMPI_COMM_WORLD_RANK']))

    # NEW: lookup number of ranks in the job, and our rank
    args.world_size = torch.distributed.get_world_size()
    print("world size ", args.world_size)
    args.rank = torch.distributed.get_rank()
    print("rank ", args.rank)
    ngpus_per_node = torch.cuda.device_count()
    print("ngpus_per_node ", ngpus_per_node)
    local_rank = args.rank % ngpus_per_node
    print("local_rank ", local_rank)

    # NEW: Globalize variables
    global best_acc1
    global best_acc5
    global best_train_acc1
    global best_train_acc5

    #args.gpu = None
    # NEW: Specify gpu
    args.gpu = local_rank
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)

    # NEW: Distributed data
    #if args.distributed:
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

    #model = set_gpu(args, model)
    # NEW: Modified function for loading gpus on multinode setups
    model = lassen_set_gpu(args, model)

    if args.pretrained:
        pretrained(args, model)

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        #criterion = nn.CrossEntropyLoss().cuda()
        # NEW: Specify gpu
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # Optionally evaluate the model and return
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
        args.ckpt_base_dir = ckpt_base_dir

    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        writer = SummaryWriter(log_dir=log_base_dir)
    else:
        writer = None

    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)

    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        progress_overall = ProgressMeter(
            1, [epoch_time, validation_time, train_time],
            prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        save_checkpoint(
            {
                "epoch": 0,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "best_acc5": best_acc5,
                "best_train_acc1": best_train_acc1,
                "best_train_acc5": best_train_acc5,
                "optimizer": optimizer.state_dict(),
                "curr_acc1": acc1 if acc1 else "Not evaluated",
            },
            False,
            filename=ckpt_base_dir / f"initial.state",
            save=False,
        )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        # NEW: Distributed data
        #if args.distributed:
        data.train_sampler.set_epoch(epoch)
        data.val_sampler.set_epoch(epoch)

        lr_policy(epoch, iteration=None)
        #modifier(args, epoch, model)

        cur_lr = get_lr(optimizer)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        #train_acc1, train_acc5 = train(
        #    data.train_loader, model, criterion, optimizer, epoch, args, writer=None
        #)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()

        # NEW: Only write values to tensorboard for main processor (one with global rank 0)
        if args.rank == 0:
            acc1, acc5 = validate(data.val_loader, model, criterion, args,
                                  writer, epoch)
        else:
            acc1, acc5 = validate(data.val_loader, model, criterion, args,
                                  None, epoch)

        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        # check save_every > 0 before the modulo to avoid division by zero
        save = args.save_every > 0 and ((epoch % args.save_every) == 0)

        # NEW: Only do for main processor (one with global rank 0)
        if args.rank == 0:
            if is_best or save or epoch == args.epochs - 1:
                if is_best:
                    print(
                        f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                    )

                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "arch": args.arch,
                        "state_dict": model.state_dict(),
                        "best_acc1": best_acc1,
                        "best_acc5": best_acc5,
                        "best_train_acc1": best_train_acc1,
                        "best_train_acc5": best_train_acc5,
                        "optimizer": optimizer.state_dict(),
                        "curr_acc1": acc1,
                        "curr_acc5": acc5,
                    },
                    is_best,
                    filename=ckpt_base_dir / f"epoch_most_recent.state",
                    save=save,
                )
                #filename=ckpt_base_dir / f"epoch_{epoch}.state",

        epoch_time.update((time.time() - end_epoch) / 60)

        # NEW: Only do for main processor (one with global rank 0)
        if args.rank == 0:
            progress_overall.display(epoch)
            progress_overall.write_to_tensorboard(writer,
                                                  prefix="diagnostics",
                                                  global_step=epoch)

            if args.conv_type == "SampleSubnetConv":
                count = 0
                sum_pr = 0.0
                for n, m in model.named_modules():
                    if isinstance(m, SampleSubnetConv):
                        # avg pr across 10 samples
                        pr = 0.0
                        for _ in range(10):
                            pr += ((torch.rand_like(m.clamped_scores) >=
                                    m.clamped_scores).float().mean().item())
                        pr /= 10.0
                        writer.add_scalar("pr/{}".format(n), pr, epoch)
                        sum_pr += pr
                        count += 1

                args.prune_rate = sum_pr / count
                writer.add_scalar("pr/average", args.prune_rate, epoch)

        # NEW: Only do for main processor (one with global rank 0)
        if args.rank == 0:
            writer.add_scalar("test/lr", cur_lr, epoch)

        end_epoch = time.time()

    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        write_result_to_csv(
            best_acc1=best_acc1,
            best_acc5=best_acc5,
            best_train_acc1=best_train_acc1,
            best_train_acc5=best_train_acc5,
            prune_rate=args.prune_rate,
            curr_acc1=acc1,
            curr_acc5=acc5,
            base_config=args.config,
            name=args.name,
        )
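
Example 3 takes its world size and rank from Open MPI's OMPI_COMM_WORLD_* environment variables, but init_method="env://" also requires MASTER_ADDR and MASTER_PORT to be set before init_process_group runs, and lassen_set_gpu presumably pins the process to its local GPU and wraps the model in DistributedDataParallel. A minimal sketch of that glue, under those assumptions (the helper's real behavior is not shown in the listing):

import os

import torch


def init_distributed(master_addr="127.0.0.1", master_port="29500"):
    """Bridge Open MPI env vars to torch.distributed's env:// rendezvous (sketch)."""
    os.environ.setdefault("MASTER_ADDR", master_addr)
    os.environ.setdefault("MASTER_PORT", master_port)
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    torch.distributed.init_process_group(
        backend="nccl", init_method="env://",
        world_size=world_size, rank=rank)
    # one process per GPU: pin this rank to its local device
    local_rank = rank % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank


def wrap_model_ddp(model, local_rank):
    """What a lassen_set_gpu-style helper plausibly does (assumption)."""
    model = model.cuda(local_rank)
    return torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank])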
Example 4
def main_worker(args):
    args.gpu = None
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)

    if args.pretrained:
        pretrained(args, model)

    # SJT modification:
    if args.exp_mode:  # pretraining/pruning/finetuning
        exp_mode = args.exp_mode
        if exp_mode == "pretraining":
            # YHT modification: set the prune rate to 0 for pretraining
            print(
                "exp_mode is pretraining; setting prune-rate to 0"
            )
            args.prune_rate = 0
            unfreeze_model_weights(model)
            freeze_model_subnet(model)
    # End of SJT modification

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        # SJT modification
        if args.exp_mode:
            if args.exp_mode == "pruning":
                optimizer = resume_pruning(args, model)
            else:  # can only be "finetuning"
                if args.exp_mode != "finetuning":
                    print(
                        "resume must be combined with the pruning or finetuning exp_mode!"
                    )
                    return
                else:
                    optimizer = resume_finetuning(args, model)
                    # YHT: not sure whether it is needed
                    #lr_policy = get_policy(args.lr_policy)(optimizer, args)
                    #print("#####################DEBUG PRINT : VALIDATE FIRST#####################")
                    #validate(data.val_loader, model, criterion, args, writer= None, epoch=args.start_epoch)
        else:
            best_acc1 = resume(args, model, optimizer)
        # End of SJT modification
    else:
        # YHT modification
        if args.exp_mode:
            if args.exp_mode == "finetuning":
                # here we assume the user wants to finetune the subnetwork from the initial prune-rate vector
                print(
                    "Using finetuning mode without resume; assuming finetuning from the initial state."
                )
                optimizer = resume_finetuning(args, model)
                # YHT: not sure whether it is needed
                lr_policy = get_policy(args.lr_policy)(optimizer, args)
        # End of modification

    # Optionally evaluate the model and return
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
    args.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / f"initial.state",
        save=False,
    )

    if args.gp_warm_up:
        record_prune_rate = args.prune_rate
    if args.print_more:
        print_global_layerwise_prune_rate(model, args.prune_rate)

    # YHT modification (May 20):
    # by this point every prune-rate is accurate;
    # now create the mask if prandom is true
    if args.prandom:
        make_prandom_mask(model)
    # End of modification

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        modifier(args, epoch, model)
        cur_lr = get_lr(optimizer)
        if args.print_more:
            print("In epoch{epoch}, lr = {cur_lr}")
        # train for one epoch
        start_train = time.time()
        # WHN modification: add global pruning
        if args.pscale == "global":
            if args.gp_warm_up:
                if epoch < args.gp_warm_up_epochs:
                    args.prune_rate = 0
                else:
                    args.prune_rate = record_prune_rate
            if not args.prandom:
                args.score_threshold = get_global_score_threshold(
                    model, args.prune_rate)

        # YHT modification
        if args.print_more:
            print_global_layerwise_prune_rate(model, args.prune_rate)
        # End of modification

        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()
        # if labels are randomly shuffled, evaluate on the training set instead (by yty)
        if args.shuffle:
            acc1, acc5 = train_acc1, train_acc5
        else:
            acc1, acc5 = validate(data.val_loader, model, criterion, args,
                                  writer, epoch)

        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        # check save_every > 0 before the modulo to avoid division by zero
        save = args.save_every > 0 and ((epoch % args.save_every) == 0)
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        if args.conv_type == "SampleSubnetConv":
            count = 0
            sum_pr = 0.0
            for n, m in model.named_modules():
                if isinstance(m, SampleSubnetConv):
                    # avg pr across 10 samples
                    pr = 0.0
                    for _ in range(10):
                        pr += ((torch.rand_like(m.clamped_scores) >=
                                m.clamped_scores).float().mean().item())
                    pr /= 10.0
                    writer.add_scalar("pr/{}".format(n), pr, epoch)
                    sum_pr += pr
                    count += 1

            args.prune_rate = sum_pr / count
            writer.add_scalar("pr/average", args.prune_rate, epoch)

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
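
Example 4's global-pruning path calls get_global_score_threshold(model, args.prune_rate), which the listing never defines. A plausible minimal sketch, assuming the convention that scored modules expose a scores tensor and that weights whose score falls at or below the threshold are pruned (both assumptions; the actual attribute names and tie-breaking may differ):

import torch


def get_global_score_threshold(model, prune_rate):
    """Score below which a fraction prune_rate of all scored weights lies (sketch)."""
    all_scores = torch.cat([
        m.scores.detach().abs().flatten()
        for m in model.modules()
        if hasattr(m, "scores")
    ])
    k = int(prune_rate * all_scores.numel())
    if k < 1:
        # prune_rate of 0 (e.g. during gp_warm_up): prune nothing
        return all_scores.min() - 1
    # k-th smallest score: everything at or below it gets masked out
    return torch.kthvalue(all_scores, k).values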