def main():
    '''Set the random seed, choose GPU or CPU, and redirect stdout to a log file; then build the data, model, loss, and optimizer and run training/evaluation.'''
    torch.manual_seed(args.seed)  # args.seed defaults to 1; fixing the seed makes runs reproducible
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False
    if not args.evaluate:  # training mode logs to log_train.txt; test mode logs to log_test.txt
        sys.stdout = Logger(osp.join(args.save_dir,
                                     'log_train.txt'))  # redirect stdout into the log file
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))  # 打印所有arg中使用的参数

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))  # 打印使用的GPU
        cudnn.benchmark = True  # 使用gpu
        torch.cuda.manual_seed_all(args.seed)  # 传入随机数的总数,牵扯到随机初始化,加上这个保证结果稳定复现
    else:
        print("Currently using CPU (GPU is highly recommended)")

    # 1. build the dataset  2. define the transforms  3. build dataloaders that yield tensors for training
    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_img_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )  # root: dataset root directory; name: which dataset; split_id and the cuhk03_* flags only apply to CUHK03
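    # A rough sketch of what the returned dataset object is assumed to expose,
    # inferred from how it is used below rather than from data_manager itself:
    #   dataset.train / dataset.query / dataset.gallery -> lists of
    #       (img_path, pid, camid) samples consumed by ImageDataset
    #   dataset.num_train_pids -> number of distinct training identities,
    #       used below as the classifier's number of classes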

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),  # data augmentation is applied only at training time
        T.RandomHorizontalFlip(),
        T.ToTensor(),  # convert the PIL image to a tensor
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),  # normalize with the conventional ImageNet mean/std
    ])
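    # Random2DTranslation is assumed to work by slightly enlarging the image and
    # then taking a random (height, width) crop, which jitters the person's
    # position inside the frame; conceptually:
    #   enlarged = resize(img, (round(width * 1.125), round(height * 1.125)))
    #   output   = random_crop(enlarged, (height, width))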

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),  # test images vary in size, so resize them all to a common size
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu
    # three dataloaders follow; only trainloader applies augmentation
    trainloader = DataLoader(
        ImageDataset(dataset.train,
                     transform=transform_train),  # the training transform includes augmentation
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,  # number of worker processes used to load data
        pin_memory=pin_memory,  # pinned host memory speeds up transfers to the GPU
        drop_last=True,  # e.g. 100 images at batch 32 give 3 full batches; the leftover 4 are dropped so every batch has the same size
    )
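    # Assuming ImageDataset yields one (img, pid, camid) triple per sample, each
    # trainloader iteration then produces roughly:
    #   imgs:   FloatTensor of shape [train_batch, 3, height, width]
    #   pids:   LongTensor of shape [train_batch]  (identity labels)
    #   camids: LongTensor of shape [train_batch]  (camera ids)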

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,  # no shuffling at test time
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,  # keep every image at test time
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    # model selection: models.init_model builds the backbone chosen by args.arch (e.g. resnet50)
    print("Initializing model: {}".format(args.arch))
    model = models.init_model(
        name=args.arch,
        num_classes=dataset.num_train_pids,
        loss={'xent'},
        use_gpu=use_gpu)  # one class per training identity (pid)
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    # 1. loss  2. optimizer  3. learning-rate schedule  4. resume from checkpoint  5. wrap the model for multi-GPU training
    criterion = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids,
                                        use_gpu=use_gpu)
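    # CrossEntropyLabelSmooth is the label-smoothed cross-entropy: with N classes
    # and a smoothing factor eps (commonly 0.1, set inside the criterion), the
    # one-hot target is softened to
    #   q_k = (1 - eps) * 1[k == y] + eps / N
    # and the loss is -sum_k q_k * log(p_k), which discourages over-confident
    # predictions on the training identities.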
    optimizer = init_optim(
        args.optim, model.parameters(), args.lr,
        args.weight_decay)  # optimize all model parameters; weight_decay adds L2 regularization
    # to update only selected layers, pass just those layers' parameters, e.g.:
    # optimizer = init_optim(args.optim, nn.Sequential(model.conv1, model.conv2).parameters(), args.lr, args.weight_decay)

    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(
            optimizer, step_size=args.stepsize, gamma=args.gamma
        )  # step decay: multiply the learning rate by args.gamma every args.stepsize epochs
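    # StepLR yields a staircase decay; effectively
    #   lr(epoch) = args.lr * args.gamma ** (epoch // args.stepsize)
    # e.g. with lr=3e-4, stepsize=20, gamma=0.1: 3e-4 until epoch 20, then 3e-5,
    # then 3e-6 from epoch 40, and so on.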
    start_epoch = args.start_epoch  # epoch to start from (lets you pause training and pick it up later)

    if args.resume:  # resume training: load a previously saved checkpoint back into the model
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)  # load the checkpoint from the given path
        model.load_state_dict(checkpoint['state_dict'])  # 1. copy the saved weights into the model
        start_epoch = checkpoint['epoch']  # 2. epoch recorded in the checkpoint
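        # Note: only the weights and the epoch counter are restored; this script's
        # checkpoints do not store optimizer state, but if they did it could be
        # restored with optimizer.load_state_dict(checkpoint['optimizer']).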

    if use_gpu:
        model = nn.DataParallel(model).cuda()  # wrap the model in DataParallel to train on multiple GPUs
        # once wrapped, the underlying model's parameters are model.module.parameters()

    if args.evaluate:  # evaluation-only mode
        print("Evaluate only")
        test(model, queryloader, galleryloader, use_gpu)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0

    print("==> Start training")
    for epoch in range(start_epoch, args.max_epoch):  # loop from the starting epoch to the maximum epoch
        start_train_time = time.time()
        train(epoch, model, criterion, optimizer, trainloader,
              use_gpu)  # run one training epoch (see the train() function)
        train_time += round(time.time() - start_train_time)

        if args.stepsize > 0: scheduler.step()  # step the scheduler; without this the learning rate never decays

        # evaluate every args.eval_step epochs once past args.start_eval, and always at the final epoch
        if ((epoch + 1) > args.start_eval and args.eval_step > 0
                and (epoch + 1) % args.eval_step == 0) or (epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            # bundle the training state and save it with save_checkpoint
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) +
                         '.pth.tar'))  # checkpoints are written under args.save_dir

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))
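

# Variant of the training entry point: the same pipeline, but trained with a
# ring loss term on top of the label-smoothed cross-entropy. (Both functions
# are named main(); defined in one module the second would shadow the first,
# so the two presumably come from sibling scripts.)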
def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu: use_gpu = False

    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'))
    else:
        sys.stdout = Logger(osp.join(args.save_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_img_dataset(
        root=args.root,
        name=args.dataset,
        split_id=args.split_id,
        cuhk03_labeled=args.cuhk03_labeled,
        cuhk03_classic_split=args.cuhk03_classic_split,
    )

    transform_train = T.Compose([
        T.Random2DTranslation(args.height, args.width),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    transform_test = T.Compose([
        T.Resize((args.height, args.width)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    pin_memory = use_gpu

    trainloader = DataLoader(
        ImageDataset(dataset.train, transform=transform_train),
        batch_size=args.train_batch,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        ImageDataset(dataset.query, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              num_classes=dataset.num_train_pids,
                              loss={'ring'},
                              use_gpu=use_gpu)
    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    criterion_xent = CrossEntropyLabelSmooth(
        num_classes=dataset.num_train_pids, use_gpu=use_gpu)
    criterion_ring = RingLoss(args.weight_ring)
    if use_gpu: criterion_ring = criterion_ring.cuda()
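    # RingLoss is assumed here to pull feature norms toward a learnable target
    # radius R, weighted by args.weight_ring; roughly
    #   L_ring = weight_ring * mean((||f||_2 - R) ** 2)
    # Since R is itself a learnable parameter, criterion_ring.parameters() is
    # added to the optimizer below alongside the model parameters.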

    params = list(model.parameters()) + list(criterion_ring.parameters())
    optimizer = init_optim(args.optim, params, args.lr, args.weight_decay)
    if args.stepsize > 0:
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=args.stepsize,
                                        gamma=args.gamma)
    start_epoch = args.start_epoch

    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    if args.evaluate:
        print("Evaluate only")
        test(model, queryloader, galleryloader, use_gpu)
        return

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(start_epoch, args.max_epoch):
        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_ring, optimizer,
              trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        if args.stepsize > 0: scheduler.step()

        if ((epoch + 1) > args.start_eval and args.eval_step > 0
                and (epoch + 1) % args.eval_step == 0) or (epoch + 1) == args.max_epoch:
            print("==> Test")
            rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir,
                         'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

    print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(
        best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        "Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".
        format(elapsed, train_time))