def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    # Set the learning rate: draw a random factor if none is given,
    # then scale linearly with the number of visible GPUs.
    if args.lr_factor == -1:
        args.lr_factor = random()
    args.lr = args.lr_factor * 10**-args.lr_base
    args.lr *= len(args.gpu_devices.split(','))
    print(f"Choose learning rate {args.lr}")

    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'), mode='a')
    print("==========\nArgs:{}\n==========".format(args))
    #assert torch.distributed.is_available()
    #print("Initializing DDP by nccl-tcp({}) rank({}) world_size({})".format(args.init_method, args.rank, args.world_size))
    #dist.init_process_group(backend='nccl', init_method=args.init_method, rank=args.rank, world_size=args.world_size)

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset, root=args.root)

    # Data augmentation
    spatial_transform_train = [
        ST.Scale((args.height, args.width), interpolation=3),
        ST.RandomHorizontalFlip(),
        ST.ToTensor(),
        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ST.RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406])
    ]
    spatial_transform_train = ST.Compose(spatial_transform_train)
    temporal_transform_train = TT.TemporalRandomCrop(size=args.seq_len, stride=args.sample_stride)
    #temporal_transform_train = TT.TemporalRandomCropPick(size=args.seq_len, stride=args.sample_stride)

    spatial_transform_test = ST.Compose([
        ST.Scale((args.height, args.width), interpolation=3),
        ST.ToTensor(),
        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    temporal_transform_test = TT.TemporalBeginCrop(size=args.test_frames)

    pin_memory = True if use_gpu else False

    dataset_train = dataset.train
    if args.dataset == 'duke':
        dataset_train = dataset.train_dense
        print('process duke dataset')

    # Camera-aware identity sampling for datasets that provide camera ids;
    # otherwise fall back to plain identity sampling so `sampler` is always defined.
    if args.dataset == 'lsvid':
        sampler = RandomIdentityCameraSampler(dataset_train, num_instances=args.num_instances, num_cam=dataset.num_camids)
    elif args.dataset == 'mars':
        sampler = RandomIdentityCameraSampler(dataset_train, num_instances=args.num_instances, num_cam=dataset.num_camids)
    else:
        sampler = RandomIdentitySampler(dataset_train, num_instances=args.num_instances)

    trainloader = DataLoader(
        VideoDataset(dataset_train, spatial_transform=spatial_transform_train, temporal_transform=temporal_transform_train),
        sampler=sampler,
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
    '''
    for batch_idx, (vids, pids, camids, img_paths) in enumerate(trainloader):
        print(batch_idx, pids, camids, img_paths)
        break
    return
    '''

    dataset_query = dataset.query
    dataset_gallery = dataset.gallery
    if args.dataset == 'lsvid':
        dataset_query = dataset.val_query
        dataset_gallery = dataset.val_gallery
        print('process lsvid dataset')

    queryloader = DataLoader(
        VideoDataset(dataset_query, spatial_transform=spatial_transform_test, temporal_transform=temporal_transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    galleryloader = DataLoader(
        VideoDataset(dataset_gallery, spatial_transform=spatial_transform_test, temporal_transform=temporal_transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              use_gpu=use_gpu,
                              num_classes=dataset.num_train_pids,
                              loss={'xent', 'htri'},
                              vis=True)
    #print(model)

    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])

    criterion_xent = nn.CrossEntropyLoss()
    criterion_htri = TripletLoss(margin=args.margin, distance=args.distance, use_gpu=use_gpu)
    criterion_htri_c = TripletInterCamLoss(margin=args.margin, distance=args.distance, use_gpu=use_gpu)
    criterion_attn = MultiAttnSimLoss()
    #criterion_htri_c = TripletWeightedInterCamLoss(margin=args.margin, distance=args.distance, use_gpu=use_gpu, alpha=args.cam_alpha)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = WarmupMultiStepLR(optimizer,
                                  milestones=args.stepsize,
                                  gamma=args.gamma,
                                  warmup_factor=1.0 / 10,
                                  warmup_iters=10,
                                  warmup_method="linear")
    start_epoch = args.start_epoch

    if use_gpu:
        model = nn.DataParallel(model).cuda()
        #model = model.cuda()
        #model = nn.parallel.DistributedDataParallel(model)

    start_time = time.time()
    train_time = 0
    best_rank1 = -np.inf
    best_epoch = 0
    print("==> Start training")

    for epoch in range(start_epoch, args.max_epoch):
        #print("Set sampler seed to {}".format(args.seed*epoch))
        #sampler.set_seed(args.seed*epoch)

        # When resuming, fast-forward through already-trained epochs so the
        # lr schedule stays aligned with the checkpoint.
        if args.resume and epoch + 1 <= args.resume_epoch:
            print(f"Skip epoch {epoch+1}")
            scheduler.step()
            continue

        start_train_time = time.time()
        train(epoch, model, criterion_xent, criterion_htri, criterion_htri_c, criterion_attn, optimizer, trainloader, use_gpu)
        train_time += round(time.time() - start_train_time)
        scheduler.step()

        if ((epoch + 1) >= args.start_eval and (epoch + 1) % args.eval_step == 0) or epoch == 0:
            print("==> Test")
            with torch.no_grad():
                rank1 = test(model, queryloader, galleryloader, use_gpu)
            is_best = rank1 > best_rank1
            if is_best:
                best_rank1 = rank1
                best_epoch = epoch + 1

            if use_gpu:
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            save_checkpoint(
                {
                    'state_dict': state_dict,
                    'rank1': rank1,
                    'epoch': epoch,
                }, is_best,
                osp.join(args.save_dir, 'checkpoint_ep' + str(epoch + 1) + '.pth.tar'))

            print("==> Best Rank-1 {:.1%}, achieved at epoch {}".format(best_rank1, best_epoch))

    elapsed = round(time.time() - start_time)
    elapsed = str(datetime.timedelta(seconds=elapsed))
    train_time = str(datetime.timedelta(seconds=train_time))
    print("Finished. Total elapsed time (h:m:s): {}. Training time (h:m:s): {}.".format(elapsed, train_time))
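
# Minimal entry point for the training script above (an assumption about how the
# script is launched, not taken from the original source); `args` is expected to
# come from a module-level argparse parser elsewhere in the file, since main()
# reads it as a global.
if __name__ == '__main__':
    main()
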
def main():
    torch.manual_seed(args.seed)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
    use_gpu = torch.cuda.is_available()
    if args.use_cpu:
        use_gpu = False

    # set a learning rate
    #if args.lr_factor == -1:
    #    args.lr_factor = random()
    #args.lr = args.lr_factor * 10**-args.lr_base
    #print(f"Choose learning rate {args.lr}")

    sys.stdout = Logger(osp.join(args.save_dir, 'log_train.txt'), mode='a')
    print("==========\nArgs:{}\n==========".format(args))
    #assert torch.distributed.is_available()
    #print("Initializing DDP by nccl-tcp({}) rank({}) world_size({})".format(args.init_method, args.rank, args.world_size))
    #dist.init_process_group(backend='nccl', init_method=args.init_method, rank=args.rank, world_size=args.world_size)

    if use_gpu:
        print("Currently using GPU {}".format(args.gpu_devices))
        torch.cuda.manual_seed_all(args.seed)
    else:
        print("Currently using CPU (GPU is highly recommended)")

    print("Initializing dataset {}".format(args.dataset))
    dataset = data_manager.init_dataset(name=args.dataset, root=args.root)

    # Data augmentation (no RandomErasing in this variant)
    spatial_transform_train = [
        ST.Scale((args.height, args.width), interpolation=3),
        ST.RandomHorizontalFlip(),
        ST.ToTensor(),
        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]
    spatial_transform_train = ST.Compose(spatial_transform_train)
    temporal_transform_train = TT.TemporalRandomCrop(size=args.seq_len, stride=args.sample_stride)
    #temporal_transform_train = TT.TemporalRandomCropPick(size=args.seq_len, stride=args.sample_stride)

    spatial_transform_test = ST.Compose([
        ST.Scale((args.height, args.width), interpolation=3),
        ST.ToTensor(),
        ST.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    temporal_transform_test = TT.TemporalBeginCrop(size=args.test_frames)

    pin_memory = True if use_gpu else False

    dataset_train = dataset.train
    if args.dataset == 'duke':
        dataset_train = dataset.train_dense
        print('process duke dataset')

    # Camera-aware identity sampling (LS-VID has 15 cameras, MARS has 6);
    # otherwise fall back to plain identity sampling so `sampler` is always defined.
    if args.dataset == 'lsvid':
        sampler = RandomIdentityCameraSampler(dataset_train, num_instances=args.num_instances, num_cam=15)
    elif args.dataset == 'mars':
        sampler = RandomIdentityCameraSampler(dataset_train, num_instances=args.num_instances, num_cam=6)
    else:
        sampler = RandomIdentitySampler(dataset_train, num_instances=args.num_instances)

    trainloader = DataLoader(
        VideoDataset(dataset_train, spatial_transform=spatial_transform_train, temporal_transform=temporal_transform_train),
        sampler=sampler,
        batch_size=args.train_batch,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=True,
    )
    '''
    for batch_idx, (vids, pids, camids, img_paths) in enumerate(trainloader):
        print(batch_idx, pids, camids, img_paths)
        break
    return
    '''

    dataset_query = dataset.query
    dataset_gallery = dataset.gallery
    if args.dataset == 'lsvid':
        dataset_query = dataset.val_query
        dataset_gallery = dataset.val_gallery
        print('process lsvid dataset')

    queryloader = DataLoader(
        VideoDataset(dataset_query, spatial_transform=spatial_transform_test, temporal_transform=temporal_transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    galleryloader = DataLoader(
        VideoDataset(dataset_gallery, spatial_transform=spatial_transform_test, temporal_transform=temporal_transform_test),
        batch_size=args.test_batch,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=pin_memory,
        drop_last=False
    )

    print("Initializing model: {}".format(args.arch))
    model = models.init_model(name=args.arch,
                              use_gpu=use_gpu,
                              num_classes=dataset.num_train_pids,
                              loss={'xent', 'htri'},
                              transformer_num_heads=args.transformer_num_heads,
                              transformer_num_layers=args.transformer_num_layers,
                              attention_flatness=True)
    #print(model)

    if args.resume:
        print("Loading checkpoint from '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])

    criterion_xent = nn.CrossEntropyLoss()
    criterion_flat = FlatnessLoss(reduction='batchmean', use_gpu=use_gpu)
    criterion_htri_c = TripletInterCamLoss(margin=args.margin, distance=args.distance, use_gpu=use_gpu)
    #criterion_htri_c = TripletWeightedInterCamLoss(margin=args.margin, distance=args.distance, use_gpu=use_gpu, alpha=args.cam_alpha)

    #optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Linear lr scaling: the effective lr grows with the global batch size
    # (per-GPU batch size times the number of visible GPUs), normalized by 512.
    linear_scaled_lr = args.lr * args.train_batch * len(args.gpu_devices.split(',')) / 512.0
    args.lr = linear_scaled_lr
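    # Worked example of the scaling rule above (the numbers are illustrative, not
    # taken from any config in this repo): with a base lr of 5e-4, train_batch = 32
    # and two visible GPUs, the effective lr becomes 5e-4 * 32 * 2 / 512 = 6.25e-5.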