Example #1
def main():
    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)

    writer = None
    num_gpus = torch.cuda.device_count()
    np.random.seed(args.seed)
    args.gpu = args.local_rank % num_gpus
    args.nprocs = num_gpus
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    if args.local_rank == 0:
        args.exp = datetime.datetime.now().strftime("%YY_%mM_%dD_%HH") + "_" + \
            "{:04d}".format(random.randint(0, 1000))

    print('gpu device = %d' % args.gpu)
    print("args = %s", args)

    if args.model_type == "dynamic":
        model = dynamic_resnet20()
    elif args.model_type == "independent":
        model = Independent_resnet20()
    elif args.model_type == "slimmable":
        model = mutableResNet20()
    elif args.model_type == "original":
        model = resnet20()
    else:
        raise NotImplementedError("unsupported model_type: %s" % args.model_type)

    # model = resnet20()
    model = model.cuda(args.gpu)

    if num_gpus > 1:
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

        args.world_size = torch.distributed.get_world_size()
        args.batch_size = args.batch_size // args.world_size

    # criterion_smooth = CrossEntropyLabelSmooth(args.classes, args.label_smooth)
    # criterion_smooth = criterion_smooth.cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)
    soft_criterion = CrossEntropyLossSoft()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # scheduler = torch.optim.lr_scheduler.LambdaLR(
    #     optimizer, lambda step: (1.0-step/args.total_iters), last_epoch=-1)
    # a_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    #     optimizer, T_0=5)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    # a_scheduler = torch.optim.lr_scheduler.LambdaLR(
    #     optimizer, lambda epoch: 1 - (epoch / args.epochs))
    a_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120, 160], last_epoch=-1)  # !!
    scheduler = GradualWarmupScheduler(optimizer,
                                       1,
                                       total_epoch=5,
                                       after_scheduler=a_scheduler)

    if args.local_rank == 0:
        writer = SummaryWriter(
            "./runs/%s-%05d" %
            (time.strftime("%m-%d", time.localtime()), random.randint(0, 100)))

    # Prepare data
    train_loader = get_train_loader(args.batch_size, args.local_rank,
                                    args.num_workers)
    # originally the same as the train batch size; now reduced
    val_loader = get_val_loader(args.batch_size, args.num_workers)

    archloader = ArchLoader("data/Track1_final_archs.json")

    for epoch in range(args.epochs):
        train(train_loader, val_loader, optimizer, scheduler, model,
              archloader, criterion, soft_criterion, args, args.seed, epoch,
              writer)

        scheduler.step()
        if (epoch + 1) % args.report_freq == 0:
            top1_val, top5_val, objs_val = infer(train_loader, val_loader,
                                                 model, criterion, archloader,
                                                 args, epoch)

            if args.local_rank == 0:
                # model
                if writer is not None:
                    writer.add_scalar("Val/loss", objs_val, epoch)
                    writer.add_scalar("Val/acc1", top1_val, epoch)
                    writer.add_scalar("Val/acc5", top5_val, epoch)

                save_checkpoint({
                    'state_dict': model.state_dict(),
                }, epoch, args.exp)
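
These scripts read their hyper-parameters from an `args` object and expect `args.local_rank` to be supplied by the distributed launcher. A minimal sketch of such an argument parser (the flag names mirror the attributes used above, but the defaults are assumptions, not taken from the original code):

import argparse

def get_args():
    parser = argparse.ArgumentParser("supernet training")
    # --local_rank is injected into each worker by python -m torch.distributed.launch
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--learning_rate', type=float, default=0.1)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight_decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--report_freq', type=int, default=10)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--model_type', type=str, default='dynamic')
    return parser.parse_args()

# Typical launch, one process per GPU (script name assumed):
#   python -m torch.distributed.launch --nproc_per_node=4 train.py --model_type dynamic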

Example #2
def main():

    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)

    writer = None
    num_gpus = torch.cuda.device_count()
    np.random.seed(args.seed)
    args.gpu = args.local_rank % num_gpus
    torch.cuda.set_device(args.gpu)
    cudnn.benchmark = True
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    print('gpu device = %d' % args.gpu)
    print("args = %s", args)

    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    args.world_size = torch.distributed.get_world_size()
    args.batch_size = args.batch_size // args.world_size

    criterion_smooth = CrossEntropyLabelSmooth(args.classes, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    model = mutableResNet20()

    model = model.cuda(args.gpu)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True)

    # all_parameters = model.parameters()
    # weight_parameters = []
    # for pname, p in model.named_parameters():
    #     if p.ndimension() == 4 or 'classifier.0.weight' in pname or 'classifier.0.bias' in pname:
    #         weight_parameters.append(p)
    # weight_parameters_id = list(map(id, weight_parameters))
    # other_parameters = list(
    #     filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
    # optimizer = torch.optim.SGD(
    #     [{'params': other_parameters},
    #      {'params': weight_parameters, 'weight_decay': args.weight_decay}],
    #     args.learning_rate,
    #     momentum=args.momentum,
    # )

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    args.total_iters = args.epochs * per_epoch_iters  # // 16  # 16 is the number of sub-networks

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: (1.0 - step / args.total_iters), last_epoch=-1)

    if args.local_rank == 0:
        writer = SummaryWriter(
            "./runs/%s-%05d" %
            (time.strftime("%m-%d", time.localtime()), random.randint(0, 100)))

    # Prepare data
    train_loader = get_train_loader(args.batch_size, args.local_rank,
                                    args.num_workers, args.total_iters)
    train_dataprovider = DataIterator(train_loader)
    val_loader = get_val_loader(args.batch_size, args.num_workers)
    val_dataprovider = DataIterator(val_loader)

    archloader = ArchLoader("data/Track1_final_archs.json")

    train(train_dataprovider, val_dataprovider, optimizer, scheduler, model,
          archloader, criterion_smooth, args, val_iters, args.seed, writer)
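
DataIterator is referenced above but not defined in these examples. A minimal sketch of such a wrapper, assuming it simply cycles its DataLoader indefinitely so that training can be driven by an iteration count rather than by epochs:

class DataIterator:
    """Assumed behaviour: wrap a DataLoader and yield batches forever."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)

    def next(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # the loader is exhausted: restart it and keep going
            self.iterator = iter(self.dataloader)
            return next(self.iterator)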

Example #3
def main():
    args = get_args()

    num_gpus = torch.cuda.device_count()
    np.random.seed(args.seed)
    args.gpu = args.local_rank % num_gpus
    torch.cuda.set_device(args.gpu)

    cudnn.benchmark = True
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    args.world_size = torch.distributed.get_world_size()
    args.batch_size = args.batch_size // args.world_size

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    # archLoader
    # arch_loader=ArchLoader(args.path)
    arch_dataset = ArchDataSet(args.path)
    arch_sampler = DistributedSampler(arch_dataset)
    arch_dataloader = torch.utils.data.DataLoader(arch_dataset,
                                                  batch_size=1,
                                                  shuffle=False,
                                                  num_workers=6,
                                                  pin_memory=False,
                                                  sampler=arch_sampler)

    val_dataset = get_val_dataset()
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=6,
                                             pin_memory=False)

    print('load data successfully')

    model = mutableResNet20()

    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    model = model.cuda(args.gpu)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True)

    print("load model successfully")

    latest_model = args.weights
    if latest_model is not None:
        print('load from latest checkpoint')
        checkpoint = torch.load(latest_model)
        model.load_state_dict(checkpoint['state_dict'], strict=True)

    # attach the loss function and validation loader to args
    args.loss_function = criterion_smooth
    args.val_dataloader = val_loader

    print("start to validate model")

    validate(model, args, arch_loader=arch_dataloader)
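
CrossEntropyLabelSmooth(1000, 0.1) is constructed with a class count and a smoothing factor; its body is not shown in these listings. A common implementation consistent with that signature (an assumption, not the original class):

import torch
import torch.nn as nn

class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy with uniform label smoothing (sketch)."""

    def __init__(self, num_classes, epsilon):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # one-hot encode the targets, then mix them with the uniform distribution
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).mean(0).sum()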

Example #4
def main():
    args = get_args()

    # archLoader
    arch_loader = ArchLoader(args.path)

    # Log
    log_format = '[%(asctime)s] %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.INFO,
                        format=log_format,
                        datefmt='%m-%d %I:%M:%S')
    t = time.time()
    local_time = time.localtime(t)
    if not os.path.exists('./log'):
        os.mkdir('./log')
    fh = logging.FileHandler(
        os.path.join('log/train-{}-{:02}-{:02}-{:.3f}'.format(
            local_time.tm_year % 2000, local_time.tm_mon, local_time.tm_mday,
            t)))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)

    use_gpu = torch.cuda.is_available()

    train_dataset, val_dataset = get_dataset('cifar100')

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)

    model = mutableResNet20()

    logging.info('load model successfully')

    optimizer = torch.optim.SGD(get_parameters(model),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion_smooth = CrossEntropyLabelSmooth(1000, 0.1)

    if use_gpu:
        model = nn.DataParallel(model)
        loss_function = criterion_smooth.cuda()
        device = torch.device("cuda")
    else:
        loss_function = criterion_smooth
        device = torch.device("cpu")

    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
    #                                               lambda step: (1.0-step/args.total_iters) if step <= args.total_iters else 0, last_epoch=-1)
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    #     optimizer, T_max=200)

    model = model.to(device)

    all_iters = 0

    if args.auto_continue:  # resume automatically from the latest checkpoint
        lastest_model, iters = get_lastest_model()
        if lastest_model is not None:
            all_iters = iters
            checkpoint = torch.load(lastest_model,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info('load from checkpoint')
            # fast-forward the scheduler to the resumed iteration
            for _ in range(iters):
                scheduler.step()

    # attach optimizer, loss, scheduler and data loaders to args
    args.optimizer = optimizer
    args.loss_function = loss_function
    args.scheduler = scheduler
    args.train_loader = train_loader
    args.val_loader = val_loader

    if args.eval:
        if args.eval_resume is not None:
            checkpoint = torch.load(args.eval_resume,
                                    map_location=None if use_gpu else 'cpu')
            model.load_state_dict(checkpoint, strict=True)
            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
        exit(0)

    # warmup weights
    if args.warmup > 0:
        logging.info("begin warmup weights")
        while all_iters < args.warmup:
            all_iters = train_supernet(model,
                                       device,
                                       args,
                                       bn_process=False,
                                       all_iters=all_iters)

        validate(model,
                 device,
                 args,
                 all_iters=all_iters,
                 arch_loader=arch_loader)

    while all_iters < args.total_iters:
        logging.info("=" * 50)
        all_iters = train_subnet(model,
                                 device,
                                 args,
                                 bn_process=False,
                                 all_iters=all_iters,
                                 arch_loader=arch_loader)

        if all_iters % 200 == 0:
            logging.info("validate iter {}".format(all_iters))

            validate(model,
                     device,
                     args,
                     all_iters=all_iters,
                     arch_loader=arch_loader)
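
Example #4 builds its optimizer from get_parameters(model), which is not defined in the listing. A plausible sketch that mirrors the commented-out grouping in Example #2 (weight decay applied only to multi-dimensional weights; the exact split is an assumption):

def get_parameters(model):
    # group parameters so that BN parameters and biases receive no weight decay;
    # the remaining conv/linear weights inherit the optimizer's weight_decay
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.ndimension() <= 1 or name.endswith('.bias'):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{'params': decay},
            {'params': no_decay, 'weight_decay': 0.0}]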