def main():
    """Train a model and periodically evaluate/checkpoint it.

    Every setting is read from the module-level ``args`` namespace. In order,
    this function seeds torch, derives ``args.lr``, redirects ``sys.stdout``
    to a ``Logger`` writing ``log_train.txt`` under ``args.save_dir``, builds
    the train/query/gallery loaders, the model, the losses, the optimizer and
    the LR scheduler, and then runs epochs ``args.start_epoch`` up to (but
    not including) ``args.max_epoch``.

    A test pass plus checkpoint is performed after the first epoch and after
    every epoch ``e`` (1-based) with ``e >= args.start_eval`` and
    ``e % args.eval_step == 0``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the datasets for which
            a training sampler is configured (``'lsvid'`` or ``'mars'``).
    """
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available() and not args.use_cpu

    # Learning rate = lr_factor * 10^-lr_base, scaled by the number of
    # comma-separated entries in args.gpu_devices. A factor of -1 requests a
    # randomly drawn factor.
    if args.lr_factor == -1:
        args.lr_factor = random()
    args.lr = args.lr_factor * 10**-args.lr_base
    args.lr *= len(args.gpu_devices.split(','))
    print(f"Choose learning rate {args.lr}")

    # From here on everything printed also goes through the Logger
    # (opened in append mode).
    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'), mode='a')
    print(f"==========\nArgs:{args}\n==========")

    if use_gpu:
        print(f"Currently using GPU {args.gpu_devices}")
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print(f"Initializing dataset {args.dataset}")
    dataset = data_manager.init_dataset(name=args.dataset, root=args.root)

    # Data augmentation
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    spatial_transform_train = ST.Compose([
        ST.Scale((args.height, args.width), interpolation=3),
        ST.RandomHorizontalFlip(),
        ST.ToTensor(),
        ST.Normalize(imagenet_mean, imagenet_std),
        ST.RandomErasing(probability=0.5, mean=imagenet_mean),
    ])
    temporal_transform_train = TT.TemporalRandomCrop(
        size=args.seq_len, stride=args.sample_stride)

    spatial_transform_test = ST.Compose([
        ST.Scale((args.height, args.width), interpolation=3),
        ST.ToTensor(),
        ST.Normalize(imagenet_mean, imagenet_std),
    ])
    temporal_transform_test = TT.TemporalBeginCrop(size=args.test_frames)

    pin_memory = use_gpu

    dataset_train = dataset.train
    if args.dataset == 'duke':
        dataset_train = dataset.train_dense
        print('process duke dataset')

    # Only these datasets have a training sampler configured. Fail with a
    # clear message instead of hitting an unbound ``sampler`` further down.
    if args.dataset not in ('lsvid', 'mars'):
        raise ValueError(
            "No training sampler is configured for dataset "
            f"{args.dataset!r}; supported datasets: 'lsvid', 'mars'")
    sampler = RandomIdentityCameraSampler(dataset_train,
                                          num_instances=args.num_instances,
                                          num_cam=dataset.num_camids)

    trainloader = DataLoader(
        VideoDataset(dataset_train,
                     spatial_transform=spatial_transform_train,
                     temporal_transform=temporal_transform_train),
        sampler=sampler,
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )

    # LS-VID is evaluated on its validation query/gallery split.
    dataset_query = dataset.query
    dataset_gallery = dataset.gallery
    if args.dataset == 'lsvid':
        dataset_query = dataset.val_query
        dataset_gallery = dataset.val_gallery
        print('process lsvid dataset')

    queryloader = DataLoader(
        VideoDataset(dataset_query,
                     spatial_transform=spatial_transform_test,
                     temporal_transform=temporal_transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset_gallery,
                     spatial_transform=spatial_transform_test,
                     temporal_transform=temporal_transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False,
    )

    print(f"Initializing model: {args.arch}")
    model = models.init_model(name=args.arch,
                              use_gpu=use_gpu,
                              num_classes=dataset.num_train_pids,
                              loss={'xent', 'htri'},
                              vis=True)
    if args.resume:
        print(f"Loading checkpoint from '{args.resume}'")
        # Load onto the CPU so a checkpoint saved from a GPU run can also be
        # resumed on a CPU-only run; load_state_dict copies the values into
        # the model's own parameters.
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['state_dict'])

    criterion_xent = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin,
                                 distance=args.distance,
                                 use_gpu=use_gpu)
    criterion_htri_c = TripletInterCamLoss(margin=args.margin,
                                           distance=args.distance,
                                           use_gpu=use_gpu)
    criterion_attn = MultiAttnSimLoss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    scheduler = WarmupMultiStepLR(optimizer,
                                  milestones=args.stepsize,
                                  gamma=args.gamma,
                                  warmup_factor=1.0 / 10,
                                  warmup_iters=10,
                                  warmup_method="linear")
    start_epoch = args.start_epoch

    if use_gpu:
        model = nn.DataParallel(model).cuda()

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(start_epoch, args.max_epoch):
        # When resuming, fast-forward through already-trained epochs but
        # keep stepping the scheduler so the LR schedule stays aligned.
        if args.resume and epoch + 1 <= args.resume_epoch:
            print(f"Skip epoch {epoch+1}")
            scheduler.step()
            continue

        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, criterion_htri_c,
              criterion_attn, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)

        scheduler.step()

        # Evaluate on the regular schedule, and additionally right after the
        # very first epoch (epoch index 0).
        on_schedule = ((epoch + 1) >= args.start_eval
                       and (epoch + 1) % args.eval_step == 0)
        if on_schedule or epoch == 0:
            print("==> Test")
            with torch.no_grad():
                rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            # On GPU the model is wrapped in DataParallel (see above); save
            # the underlying module's weights so the keys carry no wrapper
            # prefix.
            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir, f'checkpoint_ep{epoch + 1}.pth.tar'))

    print(f"==> Best Rank-1 {best_rank1:.1%}, achieved at epoch {best_epoch}")

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print(
        f"Finished. Total elapsed time (h:m:s): {elapsed}. "
        f"Training time (h:m:s): {train_time}.")
def main():
    """Train the model described by the module-level ``cfg`` and evaluate it.

    A timestamped run directory is created under ``cfg.OUTPUT_DIR`` and
    ``sys.stdout`` is redirected to a ``Logger`` inside it. The function then
    seeds the RNGs, builds the dataset, model, data loaders, losses,
    optimizer and LR scheduler, and trains for ``cfg.SOLVER.MAX_EPOCHS``
    epochs. When ``cfg.SOLVER.EVAL_PERIOD`` is positive, the model is tested
    every ``EVAL_PERIOD`` epochs and after the final epoch; weights are saved
    only when rank-1 improves on the best value seen so far.

    Raises:
        ValueError: if ``cfg.MODEL.ARCH`` is not ``'video_baseline'``, the
            only architecture this function knows how to build.
    """
    runId = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    cfg.OUTPUT_DIR = os.path.join(cfg.OUTPUT_DIR, runId)
    # makedirs also creates a missing parent directory and tolerates the
    # directory already existing.
    os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
    print(cfg.OUTPUT_DIR)
    torch.manual_seed(cfg.RANDOM_SEED)
    random.seed(cfg.RANDOM_SEED)
    np.random.seed(cfg.RANDOM_SEED)
    os.environ['CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID

    use_gpu = torch.cuda.is_available() and cfg.MODEL.DEVICE == "cuda"
    log_name = 'log_test.txt' if cfg.EVALUATE_ONLY else 'log_train.txt'
    sys.stdout = Logger(osp.join(cfg.OUTPUT_DIR, log_name))

    print("==========\nConfigs:{}\n==========".format(cfg))

    if use_gpu:
        print("Currently using GPU {}".format(cfg.MODEL.DEVICE_ID))
        cudnn.benchmark = True
        torch.cuda.manual_seed_all(cfg.RANDOM_SEED)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(cfg.DATASETS.NAME))

    dataset = data_manager.init_dataset(root=cfg.DATASETS.ROOT_DIR,
                                        name=cfg.DATASETS.NAME)
    print("Initializing model: {}".format(cfg.MODEL.NAME))

    # Fail early with a clear message rather than hitting an unbound
    # ``model`` below.
    if cfg.MODEL.ARCH != 'video_baseline':
        raise ValueError(
            "Unsupported cfg.MODEL.ARCH {!r}; only 'video_baseline' is "
            "handled here".format(cfg.MODEL.ARCH))

    torch.backends.cudnn.benchmark = False
    # Size the classifier from the dataset so it always agrees with the
    # number of classes handed to the label-smoothing loss below.
    model = models.init_model(name=cfg.MODEL.ARCH,
                              num_classes=dataset.num_train_pids,
                              pretrain_choice=cfg.MODEL.PRETRAIN_CHOICE,
                              last_stride=cfg.MODEL.LAST_STRIDE,
                              neck=cfg.MODEL.NECK,
                              model_name=cfg.MODEL.NAME,
                              neck_feat=cfg.TEST.NECK_FEAT,
                              model_path=cfg.MODEL.PRETRAIN_PATH)

    print("Model size: {:.5f}M".format(
        sum(p.numel() for p in model.parameters()) / 1000000.0))

    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    transform_train = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=cfg.INPUT.PROB),
        T.Pad(cfg.INPUT.PADDING),
        T.RandomCrop(cfg.INPUT.SIZE_TRAIN),
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
        T.RandomErasing(probability=cfg.INPUT.RE_PROB,
                        mean=cfg.INPUT.PIXEL_MEAN),
    ])
    transform_test = T.Compose([
        T.Resize(cfg.INPUT.SIZE_TEST),
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std),
    ])

    pin_memory = use_gpu

    # NOTE(review): this overrides whatever the config file requested and
    # makes every DataLoader below load data in the main process. It looks
    # like a debugging leftover; confirm before removing.
    cfg.DATALOADER.NUM_WORKERS = 0

    trainloader = DataLoader(
        VideoDataset(dataset.train,
                     seq_len=cfg.DATASETS.SEQ_LEN,
                     sample=cfg.DATASETS.TRAIN_SAMPLE_METHOD,
                     transform=transform_train,
                     dataset_name=cfg.DATASETS.NAME),
        sampler=RandomIdentitySampler(
            dataset.train, num_instances=cfg.DATALOADER.NUM_INSTANCE),
        batch_size=cfg.SOLVER.SEQS_PER_BATCH,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        pin_memory=pin_memory,
        drop_last=True,
    )

    queryloader = DataLoader(
        VideoDataset(dataset.query,
                     seq_len=cfg.DATASETS.SEQ_LEN,
                     sample=cfg.DATASETS.TEST_SAMPLE_METHOD,
                     transform=transform_test,
                     max_seq_len=cfg.DATASETS.TEST_MAX_SEQ_NUM,
                     dataset_name=cfg.DATASETS.NAME),
        batch_size=cfg.TEST.SEQS_PER_BATCH,
        shuffle=False,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        pin_memory=pin_memory,
        drop_last=False,
    )

    galleryloader = DataLoader(
        VideoDataset(dataset.gallery,
                     seq_len=cfg.DATASETS.SEQ_LEN,
                     sample=cfg.DATASETS.TEST_SAMPLE_METHOD,
                     transform=transform_test,
                     max_seq_len=cfg.DATASETS.TEST_MAX_SEQ_NUM,
                     dataset_name=cfg.DATASETS.NAME),
        batch_size=cfg.TEST.SEQS_PER_BATCH,
        shuffle=False,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        pin_memory=pin_memory,
        drop_last=False,
    )

    if cfg.MODEL.SYN_BN:
        if use_gpu:
            model = nn.DataParallel(model)
        if cfg.SOLVER.FP_16:
            model = apex.parallel.convert_syncbn_model(model)
    # Device placement follows ``use_gpu`` (the same flag handed to train()
    # and test()), not the SYN_BN switch: the model must be on the GPU
    # whenever the GPU path is taken, and must stay on the CPU otherwise.
    if use_gpu:
        model.cuda()

    start_time = time.time()
    xent = CrossEntropyLabelSmooth(num_classes=dataset.num_train_pids)
    tent = TripletLoss(cfg.SOLVER.MARGIN)

    optimizer = make_optimizer(cfg, model)

    scheduler = WarmupMultiStepLR(optimizer, cfg.SOLVER.STEPS,
                                  cfg.SOLVER.GAMMA, cfg.SOLVER.WARMUP_FACTOR,
                                  cfg.SOLVER.WARMUP_ITERS,
                                  cfg.SOLVER.WARMUP_METHOD)
    no_rise = 0  # consecutive evaluations without a rank-1 improvement
    best_rank1 = 0
    start_epoch = 0
    for epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCHS):
        # NOTE(review): the scheduler is stepped before the epoch's optimizer
        # updates. Since PyTorch 1.1 the documented order is optimizer first,
        # scheduler second; confirm this ordering is intended before
        # changing it, as moving it shifts the LR schedule by one epoch.
        scheduler.step()
        print("noriase:", no_rise)
        print("==> Epoch {}/{}".format(epoch + 1, cfg.SOLVER.MAX_EPOCHS))
        print("current lr:", scheduler.get_lr()[0])

        train(model, trainloader, xent, tent, optimizer, use_gpu)

        is_last_epoch = (epoch + 1) == cfg.SOLVER.MAX_EPOCHS
        is_eval_epoch = (epoch + 1) % cfg.SOLVER.EVAL_PERIOD == 0 \
            if cfg.SOLVER.EVAL_PERIOD > 0 else False
        if cfg.SOLVER.EVAL_PERIOD > 0 and (is_eval_epoch or is_last_epoch):
            print("==> Test")

            metrics = test(model, queryloader, galleryloader,
                           cfg.TEST.TEMPORAL_POOL_METHOD, use_gpu)
            rank1 = metrics[0]
            if rank1 > best_rank1:
                best_rank1 = rank1
                no_rise = 0

                # Save the bare module's weights when the model is wrapped,
                # and the model's own weights otherwise; the wrapper is only
                # applied under cfg.MODEL.SYN_BN, so ``use_gpu`` alone does
                # not tell us whether ``model.module`` exists.
                if isinstance(model, nn.DataParallel):
                    state_dict = model.module.state_dict()
                else:
                    state_dict = model.state_dict()
                checkpoint_name = ("rank1_" + str(rank1) + '_checkpoint_ep' +
                                   str(epoch + 1) + '.pth')
                torch.save(state_dict,
                           osp.join(cfg.OUTPUT_DIR, checkpoint_name))
            else:
                no_rise += 1

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))