Example #1
def main_worker(args):
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)
    wandb.watch(model)

    if args.pretrained:
        pretrained(args, model)

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # Data loading code
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
    args.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / "initial.state",
        save=False,
    )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        modifier(args, epoch, model)

        cur_lr = get_lr(optimizer)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer,
                              epoch)
        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        save = ((epoch % args.save_every) == 0) and args.save_every > 0
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )
            wandb.log({
                "curr_acc1": acc1,
                "curr_acc5": acc5,
            })

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        if args.conv_type == "SampleSubnetConv":
            count = 0
            sum_pr = 0.0
            for n, m in model.named_modules():
                if isinstance(m, SampleSubnetConv):
                    # avg pr across 10 samples
                    pr = 0.0
                    for _ in range(10):
                        pr += ((torch.rand_like(m.clamped_scores) >=
                                m.clamped_scores).float().mean().item())
                    pr /= 10.0
                    writer.add_scalar("pr/{}".format(n), pr, epoch)
                    sum_pr += pr
                    count += 1

            args.prune_rate = sum_pr / count
            writer.add_scalar("pr/average", args.prune_rate, epoch)

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
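The timing statistics in these training loops rely on an AverageMeter helper that is not shown above. A minimal sketch, assuming only the constructor signature used in the calls here (name, a format string, and a write_avg flag), rather than the project's actual implementation:

class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self, name, fmt=":f", write_avg=True):
        self.name = name
        self.fmt = fmt
        self.write_avg = write_avg
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} (avg {avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)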
Example #2
def trn(cfg, model):

    cfg.logger.info(cfg)
    if cfg.seed is not None:
        random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)

    train, validate = get_trainer(cfg)

    if cfg.gpu is not None:
        cfg.logger.info("Use GPU: {} for training".format(cfg.gpu))

    linear_classifier_layer = model.module[1]
    optimizer = get_optimizer(cfg, linear_classifier_layer)
    cfg.logger.info(f"=> Getting {cfg.set} dataset")

    dataset = getattr(data, cfg.set)(cfg)

    lr_policy = get_policy(cfg.lr_policy)(optimizer, cfg)

    softmax_criterion = nn.CrossEntropyLoss().cuda()

    criterion = softmax_criterion

    # optionally resume from a checkpoint
    best_val_acc1 = 0.0
    best_val_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if cfg.resume:
        best_val_acc1 = resume(cfg, model, optimizer)

    run_base_dir, ckpt_base_dir, log_base_dir = path_utils.get_directories(
        cfg, cfg.gpu)
    cfg.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     cfg,
                                     prefix="Overall Timing")

    end_epoch = time.time()
    cfg.start_epoch = cfg.start_epoch or 0
    last_val_acc1 = None

    start_time = time.time()
    gpu_info = gpu_utils.GPU_Utils(gpu_index=cfg.gpu)

    # Start training
    for epoch in range(cfg.start_epoch, cfg.epochs):
        cfg.logger.info('Model conv 1 {} at epoch {}'.format(
            torch.sum(model.module[0].conv1.weight),
            epoch))  ##  make sure backbone is not updated
        if cfg.world_size > 1:
            dataset.sampler.set_epoch(epoch)
        lr_policy(epoch, iteration=None)

        cur_lr = net_utils.get_lr(optimizer)

        start_train = time.time()
        train_acc1, train_acc5 = train(dataset.trn_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       cfg,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        if (epoch + 1) % cfg.test_interval == 0:
            if cfg.gpu == cfg.base_gpu:
                # evaluate on validation set
                start_validation = time.time()
                last_val_acc1, last_val_acc5 = validate(
                    dataset.val_loader, model.module, criterion, cfg, writer,
                    epoch)
                validation_time.update((time.time() - start_validation) / 60)

                # remember best acc@1 and save checkpoint
                is_best = last_val_acc1 > best_val_acc1
                best_val_acc1 = max(last_val_acc1, best_val_acc1)
                best_val_acc5 = max(last_val_acc5, best_val_acc5)
                best_train_acc1 = max(train_acc1, best_train_acc1)
                best_train_acc5 = max(train_acc5, best_train_acc5)

                save = (((epoch + 1) % cfg.save_every)
                        == 0) and cfg.save_every > 0
                if save or epoch == cfg.epochs - 1:
                    if is_best:
                        cfg.logger.info(
                            f"==> best {last_val_acc1:.02f} saving at {ckpt_base_dir / 'model_best.pth'}"
                        )

                    net_utils.save_checkpoint(
                        {
                            "epoch": epoch + 1,
                            "arch": cfg.arch,
                            "state_dict": model.state_dict(),
                            "best_acc1": best_val_acc1,
                            "best_acc5": best_val_acc5,
                            "best_train_acc1": best_train_acc1,
                            "best_train_acc5": best_train_acc5,
                            "optimizer": optimizer.state_dict(),
                            "curr_acc1": last_val_acc1,
                            "curr_acc5": last_val_acc5,
                        },
                        is_best,
                        filename=ckpt_base_dir / f"epoch_{epoch}.state",
                        save=save or epoch == cfg.epochs - 1,
                    )

                elapsed_time = time.time() - start_time
                seconds_todo = (cfg.epochs - epoch) * (elapsed_time /
                                                       cfg.test_interval)
                estimated_time_complete = timedelta(seconds=int(seconds_todo))
                start_time = time.time()
                cfg.logger.info(
                    f"==> ETA: {estimated_time_complete}\tGPU-M: {gpu_info.gpu_mem_usage()}\tGPU-U: {gpu_info.gpu_utilization()}"
                )

                epoch_time.update((time.time() - end_epoch) / 60)
                progress_overall.display(epoch)
                progress_overall.write_to_tensorboard(writer,
                                                      prefix="diagnostics",
                                                      global_step=epoch)

                writer.add_scalar("test/lr", cur_lr, epoch)
                end_epoch = time.time()

            if cfg.world_size > 1:
                dist.barrier()
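In this linear-evaluation loop only model.module[1] (the classifier head) is handed to get_optimizer, so the frozen backbone never receives gradient updates. A minimal sketch of what that optimizer construction might reduce to; the SGD hyperparameters below are illustrative assumptions, not the project's values:

import torch

def linear_eval_optimizer(classifier_head, lr=0.1, momentum=0.9, weight_decay=0.0):
    # Only the classifier head's parameters are optimized; the backbone
    # (model.module[0]) is untouched, which is why the conv1 checksum logged
    # at the start of each epoch should stay constant.
    params = [p for p in classifier_head.parameters() if p.requires_grad]
    return torch.optim.SGD(params, lr=lr, momentum=momentum, weight_decay=weight_decay)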
Example #3
def trn(cfg, model):
    cfg.logger.info(cfg)
    if cfg.seed is not None:
        random.seed(cfg.seed)
        torch.manual_seed(cfg.seed)
        torch.cuda.manual_seed(cfg.seed)
        torch.cuda.manual_seed_all(cfg.seed)

    train, validate_knn = get_trainer(cfg)

    if cfg.gpu is not None:
        cfg.logger.info("Use GPU: {} for training".format(cfg.gpu))

    # if cfg.pretrained:
    #     net_utils.load_pretrained(cfg.pretrained,cfg.multigpu[0], model)

    optimizer = get_optimizer(cfg, model)
    cfg.logger.info(f"=> Getting {cfg.set} dataset")

    dataset = getattr(data, cfg.set)(cfg)

    lr_policy = get_policy(cfg.lr_policy)(optimizer, cfg)

    if cfg.arch == 'SimSiam':
        # L = D(p1, z2) / 2 + D(p2, z1) / 2
        base_criterion = lambda bb1_z1_p1_emb, bb2_z2_p2_emb: simsiam.SimSaimLoss(bb1_z1_p1_emb[2], bb2_z2_p2_emb[1]) / 2 +\
                                                 simsiam.SimSaimLoss(bb2_z2_p2_emb[2], bb1_z1_p1_emb[1]) / 2
    elif cfg.arch == 'SimCLR':
        base_criterion = lambda z1, z2: simclr.NT_XentLoss(z1, z2)
    else:
        raise NotImplementedError(cfg.arch)

    run_base_dir, ckpt_base_dir, log_base_dir = path_utils.get_directories(cfg, cfg.gpu)
    _, zero_gpu_ckpt_base_dir, _ = path_utils.get_directories(cfg, 0)
    # if cfg.resume:
    saved_epochs = sorted(glob.glob(str(zero_gpu_ckpt_base_dir) + '/epoch_*.state'), key=os.path.getmtime)
    # assert len(saved_epochs) < 2, 'Should be only one saved epoch -- the last one'
    if len(saved_epochs) > 0:
        cfg.resume = saved_epochs[-1]
        resume(cfg, model, optimizer)

    cfg.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(
        1, [epoch_time, validation_time, train_time], cfg, prefix="Overall Timing"
    )

    end_epoch = time.time()
    cfg.start_epoch = cfg.start_epoch or 0

    start_time = time.time()
    gpu_info = gpu_utils.GPU_Utils(gpu_index=cfg.gpu)

    cfg.logger.info('Start Training: Model conv 1 initialization {}'.format(torch.sum(model.module.backbone.conv1.weight)))
    # Start training

    for n, m in model.module.named_modules():
        if hasattr(m, "weight") and m.weight is not None:
            cfg.logger.info('{} ({}): {}'.format(n, type(m).__name__, m.weight.shape))

    criterion = base_criterion
    cfg.logger.info('Using Vanilla Criterion')
    for epoch in range(cfg.start_epoch, cfg.epochs):
        if cfg.world_size > 1:
            dataset.sampler.set_epoch(epoch)

        lr_policy(epoch, iteration=None)

        cur_lr = net_utils.get_lr(optimizer)

        start_train = time.time()
        train(dataset.trn_loader, model, criterion, optimizer, epoch, cfg, writer=writer)
        train_time.update((time.time() - start_train) / 60)

        if (epoch + 1) % cfg.test_interval == 0:
            if cfg.gpu == cfg.base_gpu:
                # evaluate on validation set
                start_validation = time.time()

                acc = validate_knn(dataset.trn_loader, dataset.val_loader, model.module, cfg, writer, epoch)

                validation_time.update((time.time() - start_validation) / 60)
                csv_utils.write_generic_result_to_csv(path=cfg.exp_dir,
                                                      name=os.path.basename(cfg.exp_dir[:-1]),
                                                      epoch=epoch,
                                                      knn_acc=acc)

                save = (((epoch + 1) % cfg.save_every) == 0) and cfg.save_every > 0
                if save or epoch == cfg.epochs - 1:
                    # if is_best:
                    # print(f"==> best {last_val_acc1:.02f} saving at {ckpt_base_dir / 'model_best.pth'}")

                    net_utils.save_checkpoint(
                        {
                            "epoch": epoch + 1,
                            "arch": cfg.arch,
                            "state_dict": model.state_dict(),
                            "ACC": acc,
                            "optimizer": optimizer.state_dict(),
                        },
                        is_best=False,
                        filename=ckpt_base_dir / f"epoch_{epoch:04d}.state",
                        save=save or epoch == cfg.epochs - 1,
                    )


                    elapsed_time = time.time() - start_time
                    seconds_todo = (cfg.epochs - epoch) * (elapsed_time / cfg.test_interval)
                    estimated_time_complete = timedelta(seconds=int(seconds_todo))
                    start_time = time.time()
                    cfg.logger.info(
                        f"==> ETA: {estimated_time_complete}\tGPU-M: {gpu_info.gpu_mem_usage()}\tGPU-U: {gpu_info.gpu_utilization()}")


                    epoch_time.update((time.time() - end_epoch) / 60)
                    progress_overall.display(epoch)
                    progress_overall.write_to_tensorboard(
                        writer, prefix="diagnostics", global_step=epoch
                    )

                    writer.add_scalar("test/lr", cur_lr, epoch)
                    end_epoch = time.time()


            if cfg.world_size > 1:
                # cfg.logger.info('GPU {} going into the barrier'.format(cfg.gpu))
                dist.barrier()
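The base_criterion lambda for SimSiam composes simsiam.SimSaimLoss, which is not shown in this example. A sketch of the standard symmetric SimSiam objective it presumably implements (negative cosine similarity with a stop-gradient on the target embedding):

import torch.nn.functional as F

def negative_cosine(p, z):
    # D(p, z) = -cos(p, stopgrad(z)); the stop-gradient is what prevents collapse.
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()

def simsiam_loss(z1, p1, z2, p2):
    # L = D(p1, z2) / 2 + D(p2, z1) / 2, matching the lambda above, where the
    # tuples passed to it carry (backbone_features, z, p) per view.
    return negative_cosine(p1, z2) / 2 + negative_cosine(p2, z1) / 2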
Example #4
def main_worker(args):
    # NEW: equivalent to MPI init.
    print("world size ", os.environ['OMPI_COMM_WORLD_SIZE'])
    print("rank ", os.environ['OMPI_COMM_WORLD_RANK'])
    torch.distributed.init_process_group(
        backend="nccl",
        init_method="env://",
        world_size=int(os.environ['OMPI_COMM_WORLD_SIZE']),
        rank=int(os.environ['OMPI_COMM_WORLD_RANK']))

    # NEW: lookup number of ranks in the job, and our rank
    args.world_size = torch.distributed.get_world_size()
    print("world size ", args.world_size)
    args.rank = torch.distributed.get_rank()
    print("rank ", args.rank)
    ngpus_per_node = torch.cuda.device_count()
    print("ngpus_per_node ", ngpus_per_node)
    local_rank = args.rank % ngpus_per_node
    print("local_rank ", local_rank)

    # NEW: Globalize variables
    global best_acc1
    global best_acc5
    global best_train_acc1
    global best_train_acc5

    #args.gpu = None
    # NEW: Specify gpu
    args.gpu = local_rank
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)

    # NEW: Distributed data
    #if args.distributed:
    args.batch_size = int(args.batch_size / ngpus_per_node)
    args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)

    #model = set_gpu(args, model)
    # NEW: Modified function for loading gpus on multinode setups
    model = lassen_set_gpu(args, model)

    if args.pretrained:
        pretrained(args, model)

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        #criterion = nn.CrossEntropyLoss().cuda()
        # NEW: Specify gpu
        criterion = nn.CrossEntropyLoss().cuda(args.gpu)
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # Data loading code
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
        args.ckpt_base_dir = ckpt_base_dir

    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        writer = SummaryWriter(log_dir=log_base_dir)
    else:
        writer = None

    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)

    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        progress_overall = ProgressMeter(
            1, [epoch_time, validation_time, train_time],
            prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        save_checkpoint(
            {
                "epoch": 0,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "best_acc5": best_acc5,
                "best_train_acc1": best_train_acc1,
                "best_train_acc5": best_train_acc5,
                "optimizer": optimizer.state_dict(),
                "curr_acc1": acc1 if acc1 else "Not evaluated",
            },
            False,
            filename=ckpt_base_dir / "initial.state",
            save=False,
        )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        # NEW: Distributed data
        #if args.distributed:
        data.train_sampler.set_epoch(epoch)
        data.val_sampler.set_epoch(epoch)

        lr_policy(epoch, iteration=None)
        #modifier(args, epoch, model)

        cur_lr = get_lr(optimizer)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        #train_acc1, train_acc5 = train(
        #    data.train_loader, model, criterion, optimizer, epoch, args, writer=None
        #)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()

        # NEW: Only write values to tensorboard for main processor (one with global rank 0)
        if args.rank == 0:
            acc1, acc5 = validate(data.val_loader, model, criterion, args,
                                  writer, epoch)
        else:
            acc1, acc5 = validate(data.val_loader, model, criterion, args,
                                  None, epoch)

        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        save = ((epoch % args.save_every) == 0) and args.save_every > 0

        # NEW: Only do for main processor (one with global rank 0)
        if args.rank == 0:
            if is_best or save or epoch == args.epochs - 1:
                if is_best:
                    print(
                        f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                    )

                save_checkpoint(
                    {
                        "epoch": epoch + 1,
                        "arch": args.arch,
                        "state_dict": model.state_dict(),
                        "best_acc1": best_acc1,
                        "best_acc5": best_acc5,
                        "best_train_acc1": best_train_acc1,
                        "best_train_acc5": best_train_acc5,
                        "optimizer": optimizer.state_dict(),
                        "curr_acc1": acc1,
                        "curr_acc5": acc5,
                    },
                    is_best,
                    filename=ckpt_base_dir / "epoch_most_recent.state",
                    save=save,
                )
                #filename=ckpt_base_dir / f"epoch_{epoch}.state",

        epoch_time.update((time.time() - end_epoch) / 60)

        # NEW: Only do for main processor (one with global rank 0)
        if args.rank == 0:
            progress_overall.display(epoch)
            progress_overall.write_to_tensorboard(writer,
                                                  prefix="diagnostics",
                                                  global_step=epoch)

            if args.conv_type == "SampleSubnetConv":
                count = 0
                sum_pr = 0.0
                for n, m in model.named_modules():
                    if isinstance(m, SampleSubnetConv):
                        # avg pr across 10 samples
                        pr = 0.0
                        for _ in range(10):
                            pr += ((torch.rand_like(m.clamped_scores) >=
                                    m.clamped_scores).float().mean().item())
                        pr /= 10.0
                        writer.add_scalar("pr/{}".format(n), pr, epoch)
                        sum_pr += pr
                        count += 1

                args.prune_rate = sum_pr / count
                writer.add_scalar("pr/average", args.prune_rate, epoch)

        # NEW: Only do for main processor (one with global rank 0)
        if args.rank == 0:
            writer.add_scalar("test/lr", cur_lr, epoch)

        end_epoch = time.time()

    # NEW: Only do for main processor (one with global rank 0)
    if args.rank == 0:
        write_result_to_csv(
            best_acc1=best_acc1,
            best_acc5=best_acc5,
            best_train_acc1=best_train_acc1,
            best_train_acc5=best_train_acc5,
            prune_rate=args.prune_rate,
            curr_acc1=acc1,
            curr_acc5=acc5,
            base_config=args.config,
            name=args.name,
        )
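lassen_set_gpu is specific to this multi-node setup and is not shown above. A sketch of what it presumably does, assuming the process group has already been initialized as at the top of this example: pin the process to its local GPU and wrap the model in DistributedDataParallel.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model_for_ddp(model, local_rank):
    # One process per GPU: select the local device, move the model there,
    # then let DDP handle gradient all-reduce across ranks.
    torch.cuda.set_device(local_rank)
    model = model.cuda(local_rank)
    return DDP(model, device_ids=[local_rank])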
Example #5
File: main.py Project: zj15001/STR
def main_worker(args):
    args.gpu = None

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)

    # Loading pretrained model
    if args.pretrained:
        pretrained(args, model)

        # Saving a DenseConv (nn.Conv2d) compatible model
        if args.dense_conv_model:
            print(
                f"==> DenseConv compatible model, saving at {ckpt_base_dir / 'model_best.pth'}"
            )
            save_checkpoint(
                {
                    "epoch": 0,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                },
                True,
                filename=ckpt_base_dir / "epoch_pretrained.state",
                save=True,
            )
            return

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        best_acc1 = resume(args, model, optimizer)

    # Evaluation of a model
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)
        return

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / "initial.state",
        save=False,
    )

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        cur_lr = get_lr(optimizer)

        # Gradual pruning in GMP experiments
        if args.conv_type == "GMPConv" and epoch >= args.init_prune_epoch and epoch <= args.final_prune_epoch:
            total_prune_epochs = args.final_prune_epoch - args.init_prune_epoch + 1
            for n, m in model.named_modules():
                if hasattr(m, 'set_curr_prune_rate'):
                    prune_decay = (
                        1 - ((args.curr_prune_epoch - args.init_prune_epoch) /
                             total_prune_epochs))**3
                    curr_prune_rate = m.prune_rate - (m.prune_rate *
                                                      prune_decay)
                    m.set_curr_prune_rate(curr_prune_rate)

        # train for one epoch
        start_train = time.time()
        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()
        acc1, acc5 = validate(data.val_loader, model, criterion, args, writer,
                              epoch)
        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        save = ((epoch % args.save_every) == 0) and args.save_every > 0
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        # Storing sparsity and threshold statistics for STRConv models
        if args.conv_type == "STRConv":
            count = 0
            sum_sparse = 0.0
            for n, m in model.named_modules():
                if isinstance(m, STRConv):
                    sparsity, total_params, thresh = m.getSparsity()
                    writer.add_scalar("sparsity/{}".format(n), sparsity, epoch)
                    writer.add_scalar("thresh/{}".format(n), thresh, epoch)
                    sum_sparse += int(((100 - sparsity) / 100) * total_params)
                    count += total_params
            total_sparsity = 100 - (100 * sum_sparse / count)
            writer.add_scalar("sparsity/total", total_sparsity, epoch)
        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
    if args.conv_type == "STRConv":
        json_data = {}
        json_thres = {}
        sum_sparse = 0.0
        count = 0
        for n, m in model.named_modules():
            if isinstance(m, STRConv):
                sparsity, total_params, thresh = m.getSparsity()
                json_data[n] = sparsity
                sum_sparse += int(((100 - sparsity) / 100) * total_params)
                count += total_params
                json_thres[n] = thresh
        json_data["total"] = 100 - (100 * sum_sparse / count)
        os.makedirs("runs/layerwise_sparsity", exist_ok=True)
        os.makedirs("runs/layerwise_threshold", exist_ok=True)
        with open("runs/layerwise_sparsity/{}.json".format(args.name),
                  "w") as f:
            json.dump(json_data, f)
        with open("runs/layerwise_threshold/{}.json".format(args.name),
                  "w") as f:
            json.dump(json_thres, f)
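The GMPConv block near the top of this training loop follows the cubic sparsity ramp used in gradual magnitude pruning (Zhu & Gupta, 2017): the prune rate grows from 0 to its final value over the pruning window. A standalone sketch of that schedule, using the current epoch directly where the code above reads args.curr_prune_epoch:

def gmp_prune_rate(final_rate, epoch, init_prune_epoch, final_prune_epoch):
    # Cubic ramp: rate(t) = final_rate * (1 - (1 - t/T)**3)
    total_prune_epochs = final_prune_epoch - init_prune_epoch + 1
    progress = min(max(epoch - init_prune_epoch, 0), total_prune_epochs) / total_prune_epochs
    return final_rate * (1.0 - (1.0 - progress) ** 3)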
Example #6
def main_worker(args):
    args.gpu = None
    train, validate, modifier = get_trainer(args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model and optimizer
    model = get_model(args)
    model = set_gpu(args, model)

    if args.pretrained:
        pretrained(args, model)

    # SJT modification:
    if args.exp_mode:  # pretraining / pruning / finetuning
        exp_mode = args.exp_mode
        if exp_mode == "pretraining":
            # YHT modification: set the prune rate to 0 for pretraining
            print(
                "exp_mode is pretraining, setting prune-rate to 0"
            )
            args.prune_rate = 0
            unfreeze_model_weights(model)
            freeze_model_subnet(model)
    # End of SJT modification

    optimizer = get_optimizer(args, model)
    data = get_dataset(args)
    lr_policy = get_policy(args.lr_policy)(optimizer, args)

    if args.label_smoothing is None:
        criterion = nn.CrossEntropyLoss().cuda()
    else:
        criterion = LabelSmoothing(smoothing=args.label_smoothing)

    # optionally resume from a checkpoint
    best_acc1 = 0.0
    best_acc5 = 0.0
    best_train_acc1 = 0.0
    best_train_acc5 = 0.0

    if args.resume:
        # SJT modification
        if args.exp_mode:
            if args.exp_mode == "pruning":
                optimizer = resume_pruning(args, model)
            else:  # can only be "finetuning"
                if args.exp_mode != "finetuning":
                    print(
                        "resume must be combined with the pruning or finetuning exp_mode!"
                    )
                    return
                else:
                    optimizer = resume_finetuning(args, model)
                    # YHT: not sure whether this is needed
                    #lr_policy = get_policy(args.lr_policy)(optimizer, args)
                    #print("#####################DEBUG PRINT : VALIDATE FIRST#####################")
                    #validate(data.val_loader, model, criterion, args, writer= None, epoch=args.start_epoch)
        else:
            best_acc1 = resume(args, model, optimizer)
        # End of SJT modification
    else:
        # YHT modification
        if args.exp_mode:
            if args.exp_mode == "finetuning":
                # Here we assume the user wants to finetune the subnetwork from the initial prune-rate vector.
                print(
                    "Using finetuning mode without resume, which is assumed to be finetuning from the initial mask."
                )
                optimizer = resume_finetuning(args, model)
                # YHT: not sure whether this is needed
                lr_policy = get_policy(args.lr_policy)(optimizer, args)
        # End of modification

    # Data loading code
    if args.evaluate:
        acc1, acc5 = validate(data.val_loader,
                              model,
                              criterion,
                              args,
                              writer=None,
                              epoch=args.start_epoch)

        return

    # Set up directories
    run_base_dir, ckpt_base_dir, log_base_dir = get_directories(args)
    args.ckpt_base_dir = ckpt_base_dir

    writer = SummaryWriter(log_dir=log_base_dir)
    epoch_time = AverageMeter("epoch_time", ":.4f", write_avg=False)
    validation_time = AverageMeter("validation_time", ":.4f", write_avg=False)
    train_time = AverageMeter("train_time", ":.4f", write_avg=False)
    progress_overall = ProgressMeter(1,
                                     [epoch_time, validation_time, train_time],
                                     prefix="Overall Timing")

    end_epoch = time.time()
    args.start_epoch = args.start_epoch or 0
    acc1 = None

    # Save the initial state
    save_checkpoint(
        {
            "epoch": 0,
            "arch": args.arch,
            "state_dict": model.state_dict(),
            "best_acc1": best_acc1,
            "best_acc5": best_acc5,
            "best_train_acc1": best_train_acc1,
            "best_train_acc5": best_train_acc5,
            "optimizer": optimizer.state_dict(),
            "curr_acc1": acc1 if acc1 else "Not evaluated",
        },
        False,
        filename=ckpt_base_dir / "initial.state",
        save=False,
    )

    if args.gp_warm_up:
        record_prune_rate = args.prune_rate
    if args.print_more:
        print_global_layerwise_prune_rate(model, args.prune_rate)

    # YHT modification (May 20)
    # Up to this point every prune rate is accurate.
    # Now create the mask if prandom is set.
    if args.prandom:
        make_prandom_mask(model)
    # End of modification

    # Start training
    for epoch in range(args.start_epoch, args.epochs):
        lr_policy(epoch, iteration=None)
        modifier(args, epoch, model)
        cur_lr = get_lr(optimizer)
        if args.print_more:
            print("In epoch{epoch}, lr = {cur_lr}")
        # train for one epoch
        start_train = time.time()
        # WHN modification: add global pruning
        if args.pscale == "global":
            if args.gp_warm_up:
                if epoch < args.gp_warm_up_epochs:
                    args.prune_rate = 0
                else:
                    args.prune_rate = record_prune_rate
            if not args.prandom:
                args.score_threshold = get_global_score_threshold(
                    model, args.prune_rate)

        # YHT modification
        if args.print_more:
            print_global_layerwise_prune_rate(model, args.prune_rate)
        # End of modification

        train_acc1, train_acc5 = train(data.train_loader,
                                       model,
                                       criterion,
                                       optimizer,
                                       epoch,
                                       args,
                                       writer=writer)
        train_time.update((time.time() - start_train) / 60)

        # evaluate on validation set
        start_validation = time.time()
        # if labels are randomly shuffled, evaluate on the training set instead (by yty)
        if args.shuffle:
            acc1, acc5 = train_acc1, train_acc5
        else:
            acc1, acc5 = validate(data.val_loader, model, criterion, args,
                                  writer, epoch)

        validation_time.update((time.time() - start_validation) / 60)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        best_acc5 = max(acc5, best_acc5)
        best_train_acc1 = max(train_acc1, best_train_acc1)
        best_train_acc5 = max(train_acc5, best_train_acc5)

        save = ((epoch % args.save_every) == 0) and args.save_every > 0
        if is_best or save or epoch == args.epochs - 1:
            if is_best:
                print(
                    f"==> New best, saving at {ckpt_base_dir / 'model_best.pth'}"
                )

            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "best_acc5": best_acc5,
                    "best_train_acc1": best_train_acc1,
                    "best_train_acc5": best_train_acc5,
                    "optimizer": optimizer.state_dict(),
                    "curr_acc1": acc1,
                    "curr_acc5": acc5,
                },
                is_best,
                filename=ckpt_base_dir / f"epoch_{epoch}.state",
                save=save,
            )

        epoch_time.update((time.time() - end_epoch) / 60)
        progress_overall.display(epoch)
        progress_overall.write_to_tensorboard(writer,
                                              prefix="diagnostics",
                                              global_step=epoch)

        if args.conv_type == "SampleSubnetConv":
            count = 0
            sum_pr = 0.0
            for n, m in model.named_modules():
                if isinstance(m, SampleSubnetConv):
                    # avg pr across 10 samples
                    pr = 0.0
                    for _ in range(10):
                        pr += ((torch.rand_like(m.clamped_scores) >=
                                m.clamped_scores).float().mean().item())
                    pr /= 10.0
                    writer.add_scalar("pr/{}".format(n), pr, epoch)
                    sum_pr += pr
                    count += 1

            args.prune_rate = sum_pr / count
            writer.add_scalar("pr/average", args.prune_rate, epoch)

        writer.add_scalar("test/lr", cur_lr, epoch)
        end_epoch = time.time()

    write_result_to_csv(
        best_acc1=best_acc1,
        best_acc5=best_acc5,
        best_train_acc1=best_train_acc1,
        best_train_acc5=best_train_acc5,
        prune_rate=args.prune_rate,
        curr_acc1=acc1,
        curr_acc5=acc5,
        base_config=args.config,
        name=args.name,
    )
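When args.pscale == "global", a single threshold is computed over the scores of all prunable layers via get_global_score_threshold, which is not shown above. A hypothetical sketch of that computation; the "scores" attribute name is an assumption borrowed from edge-popup-style subnet layers, not necessarily this project's:

import torch

def global_score_threshold(model, prune_rate):
    # Gather every layer's scores and pick the value below which the lowest
    # prune_rate fraction of all scores falls.
    all_scores = torch.cat([
        m.scores.detach().abs().flatten()
        for m in model.modules() if hasattr(m, "scores")
    ])
    k = int(prune_rate * all_scores.numel())
    if k == 0:
        return all_scores.min().item() - 1.0  # prune nothing
    return torch.kthvalue(all_scores, k).values.item()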
Example #7
def train(params):
    # history logs
    best_valid_loss = float('inf')
    log_info = ""
    hist_train_loss, hist_valid_loss, hist_lr = [], [], []

    # dataset
    # ds_csv = "/home/ysheng/Dataset/new_dataset/meta_data.csv"
    # ds_folder = './dataset/new_dataset'
    ds_folder = params.ds_folder
    train_set = SSN_Dataset(ds_folder, True)
    train_dataloader = DataLoader(train_set,
                                  batch_size=min(len(train_set),
                                                 params.batch_size),
                                  shuffle=True,
                                  num_workers=params.workers,
                                  drop_last=True)
    valid_set = SSN_Dataset(ds_folder, False)
    valid_dataloader = DataLoader(valid_set,
                                  batch_size=min(len(valid_set),
                                                 params.batch_size),
                                  shuffle=False,
                                  num_workers=params.workers,
                                  drop_last=True)

    best_weight = ''
    # model & optimizer & scheduler & loss function
    if not params.from_baseline:
        input_channel = params.input_channel
        model = Relight_SSN(input_channel, 1)  # input is mask + human
        model.to(device)
    else:
        model = Relight_SSN(1, 1)
        model.to(device)
        baseline_checkpoint = torch.load("weights/human_baseline.pt",
                                         map_location=device)
        model.load_state_dict(baseline_checkpoint['model_state_dict'])

    if params.tbaseline:
        params.input_channel = 2
        model = baseline_2_tbaseline(model)
        model.to(device)

    if params.touch_loss:
        params.input_channel = 1
        model = baseline_2_touchloss(model)
        model.to(device)

    # resume from last saved points
    if params.resume:
        best_weight = os.path.join("weights", params.weight_file)
        checkpoint = torch.load(best_weight, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_valid_loss = checkpoint['best_loss']
        hist_train_loss = checkpoint['hist_train_loss']
        hist_valid_loss = checkpoint['hist_valid_loss']
        if 'hist_lr' in checkpoint.keys():
            hist_lr = checkpoint['hist_lr']
        print("resuming from: {}".format(best_weight))
        del checkpoint

        # tensorboard writer update history
        for i in range(0, len(hist_train_loss)):
            tensorboard_plot_loss("history train loss",
                                  hist_train_loss[:i + 1], writer)

        for i in range(0, len(hist_valid_loss)):
            tensorboard_plot_loss("history valid loss",
                                  hist_valid_loss[:i + 1], writer)

    if params.relearn:
        best_valid_loss = float('inf')

    print(torch.cuda.device_count())
    # test multiple GPUs
    if torch.cuda.device_count() > 1 and params.multi_gpu:
        print("Let's use ", torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)

    optimizer = set_model_optimizer(model, params.weight_decay)
    set_lr(optimizer, params.lr)
    print("Current LR: {}".format(get_lr(optimizer)))
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=params.patience)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=30,
                                          gamma=0.1,
                                          last_epoch=-1)

    # training states
    train_loss, valid_loss = [], []

    # training iterations
    for epoch in range(params.epochs):
        # training
        cur_train_loss = training_iteration(model, train_dataloader, optimizer,
                                            train_loss, epoch)

        # validation
        cur_valid_loss = validation_iteration(model, valid_dataloader,
                                              valid_loss, epoch)

        if params.use_schedule:
            scheduler.step()

        log_info += "Current epoch: {} Learning Rate: {}  <br>".format(
            epoch, get_lr(optimizer))
        tensorboard_log(log_info, writer, step=epoch)

        hist_train_loss.append(cur_train_loss)
        hist_valid_loss.append(cur_valid_loss)

        tensorboard_plot_loss("history train loss", hist_train_loss, writer)
        tensorboard_plot_loss("history valid loss", hist_valid_loss, writer)

        log_info += "Epoch: {} training loss: {}, valid loss: {}  <br>".format(
            epoch, cur_train_loss, cur_valid_loss)
        # save results
        if best_valid_loss > cur_valid_loss:
            log_info += "<br> ---------- Exp: {} Find better loss: {} at {} --------  <br>".format(
                exp_name, cur_valid_loss, datetime.datetime.now())
            tensorboard_log(log_info, writer, step=epoch)

            best_valid_loss = cur_valid_loss

            outfname = '{}_{}.pt'.format(exp_name, get_cur_time_stamp())
            best_weight = save_model("weights", model, optimizer, epoch,
                                     best_valid_loss, outfname,
                                     hist_train_loss, hist_valid_loss, hist_lr,
                                     params)

        outfname = '{}.pt'.format(exp_name)
        save_model("weights", model, optimizer, epoch, best_valid_loss,
                   outfname, hist_train_loss, hist_valid_loss, hist_lr, params)

        # saving loss to local directory
        plt.figure()
        plt.plot(hist_train_loss, label='train loss')
        plt.plot(hist_valid_loss, label='valid loss')
        plt.legend()
        plt.savefig('{}_loss_plot.png'.format(params.exp_name))
        plt.close()

        # termination
        if get_lr(optimizer) < 1e-7:
            break

    print("Training finished")
    return best_weight
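get_lr and set_lr are small helpers that this example assumes but does not define; minimal sketches that read and write the learning rate on the optimizer's parameter groups:

def get_lr(optimizer):
    # All parameter groups are assumed to share one learning rate.
    return optimizer.param_groups[0]['lr']

def set_lr(optimizer, lr):
    for group in optimizer.param_groups:
        group['lr'] = lr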