import datetime
import random
import sys
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter

# Assumed to come from this repository's own modules: args (parsed CLI flags),
# resnet20, mutableResNet20, CrossEntropyLabelSmooth, ArchLoader, DataIterator,
# get_train_loader, get_val_loader, train, infer, save_checkpoint, val_iters
# and per_epoch_iters.


def main():
    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)

    writer = None
    num_gpus = torch.cuda.device_count()
    np.random.seed(args.seed)
    args.gpu = args.local_rank % num_gpus
    args.nprocs = num_gpus
    torch.cuda.set_device(args.gpu)
    # Seed all RNGs for reproducibility. Note: cudnn.benchmark=True lets cuDNN
    # auto-tune (potentially non-deterministic) kernels, which works against
    # cudnn.deterministic=True; keep only one of the two depending on whether
    # speed or exact reproducibility matters more.
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.local_rank == 0:
        # Experiment tag, e.g. "2021Y_05M_01D_12H_0042"; only rank 0 needs it
        # because only rank 0 saves checkpoints.
        args.exp = datetime.datetime.now().strftime("%YY_%mM_%dD_%HH") + "_" + \
            "{:04d}".format(random.randint(0, 1000))

    print('gpu device = %d' % args.gpu)
    print('args = %s' % args)
    model = resnet20()
    # model = mutableResNet20()
    # model = dynamic_resnet20()
    model = model.cuda(args.gpu)

    if num_gpus > 1:
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=False)
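        # Note: find_unused_parameters=False is safe here because the full
        # resnet20 runs every step, so every parameter receives a gradient
        # (the supernet entry point below needs True instead).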

        # Divide the global batch across ranks so the effective batch size
        # summed over all GPUs stays at the original args.batch_size.
        args.world_size = torch.distributed.get_world_size()
        args.batch_size = args.batch_size // args.world_size

    # criterion_smooth = CrossEntropyLabelSmooth(args.classes, args.label_smooth)
    # criterion_smooth = criterion_smooth.cuda()
    criterion = torch.nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # scheduler = torch.optim.lr_scheduler.LambdaLR(
    #     optimizer, lambda step: (1.0-step/args.total_iters), last_epoch=-1)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    #     optimizer, T_0=5)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    # scheduler = torch.optim.lr_scheduler.LambdaLR(
    #     optimizer, lambda epoch: 1 - (epoch / args.epochs))
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[100, 150],
                                                     last_epoch=-1)
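    # MultiStepLR scales the LR by the default gamma of 0.1 at epochs 100 and
    # 150; the commented-out schedulers above are alternative decay policies.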

    if args.local_rank == 0:
        writer = SummaryWriter(
            "./runs/%s-%05d" %
            (time.strftime("%m-%d", time.localtime()), random.randint(0, 100)))

    # Prepare data
    train_loader = get_train_loader(args.batch_size, args.local_rank,
                                    args.num_workers)
    # Originally equal to the train batch size; reduced here, with val_iters
    # adjusted to match.
    val_loader = get_val_loader(args.batch_size, args.num_workers)

    archloader = ArchLoader("data/Track1_final_archs.json")
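    # ArchLoader holds the candidate architecture encodings for the Track1
    # search space, presumably so train() and infer() can sample sub-networks.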

    for epoch in range(args.epochs):
        train(train_loader, val_loader, optimizer, scheduler, model,
              archloader, criterion, args, val_iters, args.seed, epoch, writer)

        scheduler.step()
        if (epoch + 1) % args.report_freq == 0:
            top1_val, top5_val, objs_val = infer(train_loader, val_loader,
                                                 model, criterion, val_iters,
                                                 archloader, args)

            if args.local_rank == 0:
                # model
                if writer is not None:
                    writer.add_scalar("Val/loss", objs_val, epoch)
                    writer.add_scalar("Val/acc1", top1_val, epoch)
                    writer.add_scalar("Val/acc5", top5_val, epoch)

                save_checkpoint({
                    'state_dict': model.state_dict(),
                }, epoch, args.exp)
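

# Supernet-training variant of the entry point: trains mutableResNet20 with
# label smoothing and an iteration-driven linear LR schedule.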
def main_supernet():

    if not torch.cuda.is_available():
        print('no gpu device available')
        sys.exit(1)

    writer = None
    num_gpus = torch.cuda.device_count()
    np.random.seed(args.seed)
    args.gpu = args.local_rank % num_gpus
    torch.cuda.set_device(args.gpu)
    # Same seeding block as in main(); the benchmark/deterministic caveat
    # noted there applies here as well.
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    print('gpu device = %d' % args.gpu)
    print('args = %s' % args)

    torch.distributed.init_process_group(backend='nccl', init_method='env://')

    args.world_size = torch.distributed.get_world_size()
    args.batch_size = args.batch_size // args.world_size

    criterion_smooth = CrossEntropyLabelSmooth(args.classes, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()

    model = mutableResNet20()

    model = model.cuda(args.gpu)
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.local_rank],
        output_device=args.local_rank,
        find_unused_parameters=True)
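    # find_unused_parameters=True is required here: each step trains only a
    # sampled sub-network of the supernet, so some parameters receive no
    # gradient and DDP has to tolerate that.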

    # all_parameters = model.parameters()
    # weight_parameters = []
    # for pname, p in model.named_parameters():
    #     if p.ndimension() == 4 or 'classifier.0.weight' in pname or 'classifier.0.bias' in pname:
    #         weight_parameters.append(p)
    # weight_parameters_id = list(map(id, weight_parameters))
    # other_parameters = list(
    #     filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
    # optimizer = torch.optim.SGD(
    #     [{'params': other_parameters},
    #      {'params': weight_parameters, 'weight_decay': args.weight_decay}],
    #     args.learning_rate,
    #     momentum=args.momentum,
    # )

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    args.total_iters = args.epochs * per_epoch_iters  # // 16  # 16 is the number of sub-networks

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: (1.0 - step / args.total_iters), last_epoch=-1)
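    # Linear decay to zero over the full run; unlike the epoch-based schedule
    # in main(), this one is presumably stepped once per iteration, hence
    # total_iters = epochs * per_epoch_iters.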

    if args.local_rank == 0:
        writer = SummaryWriter(
            "./runs/%s-%05d" %
            (time.strftime("%m-%d", time.localtime()), random.randint(0, 100)))

    # Prepare data
    train_loader = get_train_loader(args.batch_size, args.local_rank,
                                    args.num_workers, args.total_iters)
    train_dataprovider = DataIterator(train_loader)
    val_loader = get_val_loader(args.batch_size, args.num_workers)
    val_dataprovider = DataIterator(val_loader)
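    # DataIterator presumably wraps each loader as an endless iterator so
    # training can be driven by iteration count instead of epochs.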

    archloader = ArchLoader("data/Track1_final_archs.json")

    train(train_dataprovider, val_dataprovider, optimizer, scheduler, model,
          archloader, criterion_smooth, args, val_iters, args.seed, writer)
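

if __name__ == '__main__':
    # Assumed entry point; call main_supernet() instead to run the supernet
    # variant.
    main()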