def main():
    start_time = time.monotonic()

    # init distributed training
    args, cfg = parge_config()
    dist = init_dist(cfg)
    synchronize()

    # init the logging file
    logger = Logger(cfg.work_dir / "log_test.txt")
    sys.stdout = logger
    print("==========\nArgs:{}\n==========".format(args))
    log_config_to_file(cfg)

    # build the model (generator only, for inference)
    model = build_gan_model(cfg, only_generator=True)["G"]
    model.cuda()

    if dist:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[cfg.gpu],
            output_device=cfg.gpu,
            find_unused_parameters=True,
        )
    elif cfg.total_gpus > 1:
        model = torch.nn.DataParallel(model)

    # load the checkpoint
    state_dict = load_checkpoint(args.resume)
    copy_state_dict(state_dict["state_dict"], model)

    # build the test data loader
    test_loader, _ = build_val_dataloader(cfg, for_clustering=True, all_datasets=True)

    # start testing
    infer_gan(
        cfg,
        model,
        test_loader[0],  # source dataset
        dataset_name=cfg.TRAIN.data_names[0],
    )

    # print the total running time
    end_time = time.monotonic()
    print("Total running time: ", timedelta(seconds=end_time - start_time))
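# Each of these main() functions is a script entry point. Assuming the usual
# single-file script layout (not shown in the excerpts above), each file would
# end with the standard guard:
if __name__ == "__main__":
    main()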
def main():
    start_time = time.monotonic()

    # init distributed training
    args, cfg = parge_config()
    dist = init_dist(cfg)
    synchronize()

    # init the logging file
    logger = Logger(cfg.work_dir / "log.txt", debug=False)
    sys.stdout = logger
    print("==========\nArgs:{}\n==========".format(args))
    log_config_to_file(cfg)

    # build the train loader
    train_loader, train_sets = build_train_dataloader(cfg, joint=False)

    # build the model
    model = build_model(cfg, 0, init=cfg.MODEL.source_pretrained)
    model.cuda()

    if dist:
        ddp_cfg = {
            "device_ids": [cfg.gpu],
            "output_device": cfg.gpu,
            "find_unused_parameters": True,
        }
        model = DistributedDataParallel(model, **ddp_cfg)
    elif cfg.total_gpus > 1:
        model = DataParallel(model)

    # build the optimizer
    optimizer = build_optimizer([model], **cfg.TRAIN.OPTIM)

    # build the learning-rate scheduler
    if cfg.TRAIN.SCHEDULER.lr_scheduler is not None:
        lr_scheduler = build_lr_scheduler(optimizer, **cfg.TRAIN.SCHEDULER)
    else:
        lr_scheduler = None

    # build the loss functions; size the memory bank first
    num_memory = 0
    for idx, dataset in enumerate(train_sets):
        if idx in cfg.TRAIN.unsup_dataset_indexes:
            # instance-level memory for unlabeled data
            num_memory += len(dataset)
        else:
            # class-level memory for labeled data
            num_memory += dataset.num_pids

    if isinstance(model, (DataParallel, DistributedDataParallel)):
        num_features = model.module.num_features
    else:
        num_features = model.num_features

    criterions = build_loss(
        cfg.TRAIN.LOSS,
        num_features=num_features,
        num_memory=num_memory,
        cuda=True,
    )

    # initialize the hybrid memory
    loaders, datasets = build_val_dataloader(
        cfg, for_clustering=True, all_datasets=True
    )
    memory_features = []
    for idx, (loader, dataset) in enumerate(zip(loaders, datasets)):
        features = extract_features(
            model, loader, dataset, with_path=False, prefix="Extract: ",
        )
        assert features.size(0) == len(dataset)
        if idx in cfg.TRAIN.unsup_dataset_indexes:
            # init memory for unlabeled data with instance features
            memory_features.append(features)
        else:
            # init memory for labeled data with class centers
            centers_dict = collections.defaultdict(list)
            for i, (_, pid, _) in enumerate(dataset):
                centers_dict[pid].append(features[i].unsqueeze(0))
            centers = [
                torch.cat(centers_dict[pid], 0).mean(0)
                for pid in sorted(centers_dict.keys())
            ]
            memory_features.append(torch.stack(centers, 0))
    del loaders, datasets
    memory_features = torch.cat(memory_features)
    criterions["hybrid_memory"]._update_feature(memory_features)

    # build the runner
    runner = SpCLRunner(
        cfg,
        model,
        optimizer,
        criterions,
        train_loader,
        train_sets=train_sets,
        lr_scheduler=lr_scheduler,
        meter_formats={"Time": ":.3f"},
        reset_optim=False,
    )

    # resume from a checkpoint
    if args.resume_from:
        runner.resume(args.resume_from)

    # start training
    runner.run()

    # load the best model
    runner.resume(cfg.work_dir / "model_best.pth")

    # final testing
    test_loaders, queries, galleries = build_test_dataloader(cfg)
    for i, (loader, query, gallery) in enumerate(zip(test_loaders, queries, galleries)):
        cmc, mAP = test_reid(
            cfg, model, loader, query, gallery, dataset_name=cfg.TEST.datasets[i]
        )

    # print the total running time
    end_time = time.monotonic()
    print("Total running time: ", timedelta(seconds=end_time - start_time))
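# A minimal, self-contained sketch of the class-center initialization used to
# seed the hybrid memory for labeled datasets above. The feature values, pids,
# and paths below are toy stand-ins for illustration, not pipeline data.
import collections

import torch

# toy "extracted features": 6 samples with 4-dim embeddings
features = torch.randn(6, 4)
# toy dataset entries as (path, pid, camid) tuples, matching the unpacking above
dataset = [("img%d.jpg" % i, pid, 0) for i, pid in enumerate([1, 0, 1, 2, 0, 2])]

# group features by identity (pid) and average each group into a class center
centers_dict = collections.defaultdict(list)
for i, (_, pid, _) in enumerate(dataset):
    centers_dict[pid].append(features[i].unsqueeze(0))
centers = [
    torch.cat(centers_dict[pid], 0).mean(0) for pid in sorted(centers_dict.keys())
]
memory_init = torch.stack(centers, 0)  # one row per pid, sorted by pid
assert memory_init.shape == (3, 4)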
def main():
    start_time = time.monotonic()

    # init distributed training
    args, cfg = parge_config()
    dist = init_dist(cfg)
    set_random_seed(cfg.TRAIN.seed, cfg.TRAIN.deterministic)
    synchronize()

    # init the logging file
    logger = Logger(cfg.work_dir / "log.txt", debug=False)
    sys.stdout = logger
    print("==========\nArgs:{}\n==========".format(args))
    log_config_to_file(cfg)

    # build the train loader
    train_loader, _ = build_train_dataloader(cfg, joint=False)

    # build the model (a dict of generators and discriminators)
    model = build_gan_model(cfg)
    for key in model.keys():
        model[key].cuda()

    if dist:
        ddp_cfg = {
            "device_ids": [cfg.gpu],
            "output_device": cfg.gpu,
            "find_unused_parameters": True,
        }
        for key in model.keys():
            model[key] = torch.nn.parallel.DistributedDataParallel(
                model[key], **ddp_cfg
            )
    elif cfg.total_gpus > 1:
        for key in model.keys():
            model[key] = torch.nn.DataParallel(model[key])

    # build the optimizers: one over both generators, one over both discriminators
    optimizer = {}
    optimizer["G"] = build_optimizer([model["G_A"], model["G_B"]], **cfg.TRAIN.OPTIM)
    optimizer["D"] = build_optimizer([model["D_A"], model["D_B"]], **cfg.TRAIN.OPTIM)

    # build the learning-rate schedulers
    if cfg.TRAIN.SCHEDULER.lr_scheduler is not None:
        lr_scheduler = [
            build_lr_scheduler(optimizer[key], **cfg.TRAIN.SCHEDULER)
            for key in optimizer.keys()
        ]
    else:
        lr_scheduler = None

    # build the loss functions
    criterions = build_loss(cfg.TRAIN.LOSS, cuda=True)

    # build the runner
    runner = GANBaseRunner(
        cfg,
        model,
        optimizer,
        criterions,
        train_loader,
        lr_scheduler=lr_scheduler,
        meter_formats={"Time": ":.3f"},
    )

    # resume from a checkpoint
    if args.resume_from:
        runner.resume(args.resume_from)

    # start training
    runner.run()

    # final inference
    test_loader, _ = build_val_dataloader(cfg, for_clustering=True, all_datasets=True)
    # source to target
    infer_gan(
        cfg,
        model["G_A"],
        test_loader[0],
        dataset_name=list(cfg.TRAIN.datasets.keys())[0],
    )
    # target to source
    infer_gan(
        cfg,
        model["G_B"],
        test_loader[1],
        dataset_name=list(cfg.TRAIN.datasets.keys())[1],
    )

    # print the total running time
    end_time = time.monotonic()
    print("Total running time: ", timedelta(seconds=end_time - start_time))
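# The G/D optimizer split above follows the usual CycleGAN-style training
# pattern: step the generators with the discriminators frozen, then step the
# discriminators on real vs. generated samples. GANBaseRunner's actual update
# loop is not shown here; the sketch below is only an illustration of that
# alternating pattern, with hypothetical toy modules and LSGAN-style losses.
import torch
import torch.nn as nn

# toy stand-ins for G_A/G_B and D_A/D_B
G_A, G_B = nn.Linear(4, 4), nn.Linear(4, 4)
D_A, D_B = nn.Linear(4, 1), nn.Linear(4, 1)

opt_G = torch.optim.Adam(list(G_A.parameters()) + list(G_B.parameters()), lr=2e-4)
opt_D = torch.optim.Adam(list(D_A.parameters()) + list(D_B.parameters()), lr=2e-4)

x = torch.randn(8, 4)  # toy batch

# generator step: update G_A/G_B to fool the (fixed) discriminators
opt_G.zero_grad()
loss_G = (D_A(G_A(x)) - 1).pow(2).mean() + (D_B(G_B(x)) - 1).pow(2).mean()
loss_G.backward()
opt_G.step()

# discriminator step: detach generator outputs so only D_A is updated
opt_D.zero_grad()
loss_D = D_A(G_A(x).detach()).pow(2).mean() + (D_A(x) - 1).pow(2).mean()
loss_D.backward()
opt_D.step()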
def update_labels(self):
    sep = "*************************"
    print(f"\n{sep} Start updating pseudo labels on epoch {self._epoch} {sep}\n")

    # generate pseudo labels
    pseudo_labels, label_centers = self.label_generator(
        self._epoch, print_freq=self.print_freq
    )

    # update the train loader with the new pseudo labels
    self.train_loader, self.train_sets = build_train_dataloader(
        self.cfg, pseudo_labels, self.train_sets, self._epoch,
    )

    # re-construct the memory; size it first
    num_memory = 0
    for idx, dataset in enumerate(self.train_sets):
        if idx in self.cfg.TRAIN.unsup_dataset_indexes:
            # cluster-level memory for unlabeled data
            num_memory += self.cfg.TRAIN.PSEUDO_LABELS.cluster_num[
                self.cfg.TRAIN.unsup_dataset_indexes.index(idx)
            ]
        else:
            # class-level memory for labeled data
            num_memory += dataset.num_pids

    if isinstance(self.model, (DataParallel, DistributedDataParallel)):
        num_features = self.model.module.num_features
    else:
        num_features = self.model.num_features

    self.criterions = build_loss(
        self.cfg.TRAIN.LOSS,
        num_features=num_features,
        num_memory=num_memory,
        cuda=True,
    )

    # initialize the memory features
    loaders, datasets = build_val_dataloader(
        self.cfg, for_clustering=True, all_datasets=True
    )
    memory_features = []
    for idx, (loader, dataset) in enumerate(zip(loaders, datasets)):
        if idx in self.cfg.TRAIN.unsup_dataset_indexes:
            # unlabeled data: use the cluster centers from the label generator
            memory_features.append(
                label_centers[self.cfg.TRAIN.unsup_dataset_indexes.index(idx)]
            )
        else:
            # labeled data: use class centers of the extracted features
            features = extract_features(
                self.model, loader, dataset, with_path=False, prefix="Extract: ",
            )
            assert features.size(0) == len(dataset)
            centers_dict = collections.defaultdict(list)
            for i, (_, pid, _) in enumerate(dataset):
                centers_dict[pid].append(features[i].unsqueeze(0))
            centers = [
                torch.cat(centers_dict[pid], 0).mean(0)
                for pid in sorted(centers_dict.keys())
            ]
            memory_features.append(torch.stack(centers, 0))
    del loaders, datasets
    memory_features = torch.cat(memory_features)
    self.criterions["hybrid_memory"]._update_feature(memory_features)

    # initialize the memory labels: each dataset gets a consecutive pid range
    memory_labels = []
    start_pid = 0
    for dataset in self.train_sets:
        num_pids = dataset.num_pids
        memory_labels.append(torch.arange(start_pid, start_pid + num_pids))
        start_pid += num_pids
    memory_labels = torch.cat(memory_labels).view(-1)
    self.criterions["hybrid_memory"]._update_label(memory_labels)

    print(f"\n{sep} Finished updating pseudo labels {sep}\n")
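# A toy illustration of the memory-label layout built at the end of
# update_labels(): each dataset contributes a consecutive block of labels,
# offset by the pid counts of the datasets before it. The counts below are
# invented for illustration.
import torch

num_pids_per_dataset = [3, 2]  # e.g., 3 clusters in dataset 0, 2 classes in dataset 1
memory_labels, start_pid = [], 0
for num_pids in num_pids_per_dataset:
    memory_labels.append(torch.arange(start_pid, start_pid + num_pids))
    start_pid += num_pids
memory_labels = torch.cat(memory_labels).view(-1)
print(memory_labels)  # tensor([0, 1, 2, 3, 4])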