def train(config_file, resume=False, iteration=10, STEP=4, **kwargs):
    """
    Parameters
    ----------
    config_file : str
        Path to the yacs configuration file merged into cfg.
    resume : bool
        If true, continue the training and append logs to the previous log.
    iteration : int
        Number of evaluation repeats on randomly re-sampled validation splits.
    STEP : int
        Number of steps to train the discriminator per batch.
    """
    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    # [PersonReID_Dataset_Downloader('./datasets', name) for name in cfg.DATASETS.NAMES]

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = make_logger("Reid_Baseline", output_dir, 'log', resume)
    if not resume:
        logger.info("Using {} GPUs".format(1))
        logger.info("Loaded configuration file {}".format(config_file))
        logger.info("Running with config:\n{}".format(cfg))

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = torch.device(cfg.DEVICE)
    epochs = cfg.SOLVER.MAX_EPOCHS
    sources = cfg.DATASETS.SOURCE
    target = cfg.DATASETS.TARGET
    pooling = cfg.MODEL.POOL
    last_stride = cfg.MODEL.LAST_STRIDE

    # tf_board_path = os.path.join(output_dir, 'tf_runs')
    # if os.path.exists(tf_board_path):
    #     shutil.rmtree(tf_board_path)
    # writer = SummaryWriter(tf_board_path)

    gan_d_param = cfg.MODEL.D_PARAM
    gan_g_param = cfg.MODEL.G_PARAM
    class_param = cfg.MODEL.CLASS_PARAM

    """Set up"""
    train_loader, _, _, num_classes = data_loader(cfg,
                                                  cfg.DATASETS.SOURCE,
                                                  merge=cfg.DATASETS.MERGE)
    num_classes_train = [
        data_loader(cfg, [source], merge=False)[3]
        for source in cfg.DATASETS.SOURCE
    ]

    # domain-loss weights based on the size of the input datasets
    bias = max(num_classes_train) / np.array(num_classes_train)
    bias = bias / bias.sum() * 5
    discriminator_loss = LabelSmoothingLoss(len(sources),
                                            weights=bias,
                                            smoothing=0.1)
    minus_generator_loss = LabelSmoothingLoss(len(sources),
                                              weights=bias,
                                              smoothing=0.)
    classification_loss = LabelSmoothingLoss(num_classes, smoothing=0.1)

    from loss.triplet_loss import TripletLoss
    triplet = TripletLoss(cfg.SOLVER.MARGIN)
    triplet_loss = lambda feat, labels: triplet(feat, labels)[0]

    module = getattr(generalizers, cfg.MODEL.NAME)
    D = getattr(module, 'Generalizer_D')(len(sources))
    G = getattr(module, 'Generalizer_G')(num_classes, last_stride, pooling)

    if resume:
        checkpoints = get_last_stats(output_dir)
        D.load_state_dict(torch.load(checkpoints[str(type(D))]))
        G.load_state_dict(torch.load(checkpoints[str(type(G))]))

    if device:
        # must be done before the optimizer generation
        D.to(device)
        G.to(device)

    discriminator_optimizer = Adam(D.parameters(),
                                   lr=cfg.SOLVER.BASE_LR,
                                   weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    generator_optimizer = Adam(G.parameters(),
                               lr=cfg.SOLVER.BASE_LR,
                               weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    discriminator_scheduler = make_scheduler(cfg, discriminator_optimizer)
    generator_scheduler = make_scheduler(cfg, generator_optimizer)

    base_epo = 0
    if resume:
        discriminator_optimizer.load_state_dict(
            torch.load(checkpoints['D_opt']))
        generator_optimizer.load_state_dict(torch.load(checkpoints['G_opt']))
        discriminator_scheduler.load_state_dict(
            torch.load(checkpoints['D_sch']))
        generator_scheduler.load_state_dict(torch.load(checkpoints['G_sch']))
        base_epo = checkpoints['epo']

    # Modify the labels:
    # RULE:
    # according to the order of names in cfg.DATASETS.NAMES, add base number

    since = time.time()
    if not resume:
        logger.info("Start training")
    batch_count = 0
    Best_R1s = [0, 0, 0, 0]  # best Rank-1 per target dataset so far
    Benchmark = [69.6, 43.7, 59.4, 78.2]  # not used below

    for epoch in range(epochs):
        # anneal = sigmoid(annealing_base + annealing_factor*(epoch+base_epo))
        anneal = max(1 - (1 / 80 * epoch), 0)
        count = 0
        running_g_loss = 0.
        running_source_loss = 0.
        running_class_acc = 0.
        running_acc_source = 0.
        running_class_loss = 0.
        reset()

        for data in tqdm(train_loader, desc='Iteration', leave=False):
            # NOTE: zip ensured the shortest dataset dominates the iteration
            D.train()
            G.train()
            images, labels, domains = data
            if device:
                D.to(device)
                G.to(device)
                images, labels, domains = (images.to(device),
                                           labels.to(device),
                                           domains.to(device))

            """Start Training D"""
            feature_vec, scores, gan_vec = G(images)

            for param in G.parameters():
                param.requires_grad = False
            for param in D.parameters():
                param.requires_grad = True

            for _ in range(STEP):
                discriminator_optimizer.zero_grad()
                pred_domain = D(
                    [v.detach() for v in gan_vec] if isinstance(gan_vec, list)
                    else gan_vec.detach())
                # NOTE: Feat output! Not Probability!
                d_losses, accs = discriminator_loss(pred_domain,
                                                    domains,
                                                    compute_acc=True)
                d_source_loss = d_losses.mean()
                d_source_acc = accs.float().mean().item()
                d_loss = d_source_loss
                w_d_loss = anneal * d_loss * gan_d_param
                w_d_loss.backward()
                discriminator_optimizer.step()

            """Start Training G"""
            for param in D.parameters():
                param.requires_grad = False
            for param in G.parameters():
                param.requires_grad = True

            generator_optimizer.zero_grad()
            g_loss = -1. * minus_generator_loss(D(gan_vec), domains).mean()
            class_loss = classification_loss(scores, labels).mean()
            tri_loss = triplet_loss(feature_vec, labels)
            class_loss = (class_loss * cfg.SOLVER.LAMBDA1 +
                          tri_loss * cfg.SOLVER.LAMBDA2)
            w_regularized_g_loss = (anneal * gan_g_param * g_loss +
                                    class_param * class_loss)
            w_regularized_g_loss.backward()
            generator_optimizer.step()

            """Stop training"""
            running_g_loss += g_loss.item()
            running_source_loss += d_source_loss.item()
            running_acc_source += d_source_acc
            # TODO: assume all batches are the same size
            running_class_loss += class_loss.item()
            class_acc = (scores.max(1)[1] == labels).float().mean().item()
            running_class_acc += class_acc

            # writer.add_scalar('D_loss', d_source_loss.item(), batch_count)
            # writer.add_scalar('D_acc', d_source_acc, batch_count)
            # writer.add_scalar('G_loss', g_loss.item(), batch_count)
            # writer.add_scalar('Class_loss', class_loss.item(), batch_count)
            # writer.add_scalar('Class_acc', class_acc, batch_count)

            torch.cuda.empty_cache()
            count = count + 1
            batch_count += 1
            # if count == 10:break

        logger.info(
            "Epoch[{}] Iteration[{}] Loss: [G] {:.3f} [D] {:.3f} [Class] {:.3f}, Acc: [Class] {:.3f} [D] {:.3f}, Base Lr: {:.2e}"
            .format(epoch + base_epo + 1, count, running_g_loss / count,
                    running_source_loss / count, running_class_loss / count,
                    running_class_acc / count, running_acc_source / count,
                    generator_scheduler.get_lr()[0]))

        generator_scheduler.step()
        discriminator_scheduler.step()

        if (epoch + base_epo + 1) % checkpoint_period == 0:
            G.cpu()
            G.save(output_dir, epoch + base_epo + 1)
            D.cpu()
            D.save(output_dir, epoch + base_epo + 1)
            torch.save(
                generator_optimizer.state_dict(),
                os.path.join(output_dir,
                             'G_opt_epo' + str(epoch + base_epo + 1) + '.pth'))
            torch.save(
                discriminator_optimizer.state_dict(),
                os.path.join(output_dir,
                             'D_opt_epo' + str(epoch + base_epo + 1) + '.pth'))
            torch.save(
                generator_scheduler.state_dict(),
                os.path.join(output_dir,
                             'G_sch_epo' + str(epoch + base_epo + 1) + '.pth'))
            torch.save(
                discriminator_scheduler.state_dict(),
                os.path.join(output_dir,
                             'D_sch_epo' + str(epoch + base_epo + 1) + '.pth'))

        # Validation
        if (epoch + base_epo + 1) % eval_period == 0:
            # Validation on Target Dataset
            for target in cfg.DATASETS.TARGET:
                mAPs = []
                cmcs = []
                for i in range(iteration):
                    set_seeds(i)
                    _, val_loader, num_query, _ = data_loader(cfg, (target, ),
                                                              merge=False,
                                                              verbose=False)
                    all_feats = []
                    all_pids = []
                    all_camids = []

                    for data in tqdm(val_loader,
                                     desc='Feature Extraction',
                                     leave=False):
                        G.eval()
                        with torch.no_grad():
                            images, pids, camids = data
                            if device:
                                G.to(device)
                                images = images.to(device)
                            feats = G(images)
                            feats /= feats.norm(dim=-1, keepdim=True)
                            all_feats.append(feats)
                            all_pids.extend(np.asarray(pids))
                            all_camids.extend(np.asarray(camids))

                    cmc, mAP = evaluation(all_feats, all_pids, all_camids,
                                          num_query)
                    mAPs.append(mAP)
                    cmcs.append(cmc)

                mAP = np.mean(np.array(mAPs))
                cmc = np.mean(np.array(cmcs), axis=0)
                mAP_std = np.std(np.array(mAPs))
                cmc_std = np.std(np.array(cmcs), axis=0)

                logger.info("Validation Results: {} - Epoch: {}".format(
                    target, epoch + 1 + base_epo))
                logger.info("mAP: {:.1%} (std: {:.3%})".format(mAP, mAP_std))
                for r in [1, 5, 10]:
                    logger.info(
                        "CMC curve, Rank-{:<3}:{:.1%} (std: {:.3%})".format(
                            r, cmc[r - 1], cmc_std[r - 1]))

        # Record Best
        if (epoch + base_epo + 1) > 60 and ((epoch + base_epo + 1) % 5 == 1 or
                                            (epoch + base_epo + 1) % 5 == 2):
            # Validation on Target Dataset
            R1s = []
            for target in cfg.DATASETS.TARGET:
                mAPs = []
                cmcs = []
                for i in range(iteration):
                    set_seeds(i)
                    _, val_loader, num_query, _ = data_loader(cfg, (target, ),
                                                              merge=False,
                                                              verbose=False)
                    all_feats = []
                    all_pids = []
                    all_camids = []

                    for data in tqdm(val_loader,
                                     desc='Feature Extraction',
                                     leave=False):
                        G.eval()
                        with torch.no_grad():
                            images, pids, camids = data
                            if device:
                                G.to(device)
                                images = images.to(device)
                            feats = G(images)
                            feats /= feats.norm(dim=-1, keepdim=True)
                            all_feats.append(feats)
                            all_pids.extend(np.asarray(pids))
                            all_camids.extend(np.asarray(camids))

                    cmc, mAP = evaluation(all_feats, all_pids, all_camids,
                                          num_query)
                    mAPs.append(mAP)
                    cmcs.append(cmc)

                mAP = np.mean(np.array(mAPs))
                cmc = np.mean(np.array(cmcs), axis=0)
                R1 = cmc[0]
                R1s.append(R1)

            if (np.array(R1s) > np.array(Best_R1s)).all():
                logger.info("Best checkpoint at {}: {}".format(
                    str(epoch + base_epo + 1),
                    ', '.join([str(s) for s in R1s])))
                Best_R1s = R1s
                G.cpu()
                G.save(output_dir, -1)
                D.cpu()
                D.save(output_dir, -1)
                torch.save(
                    generator_optimizer.state_dict(),
                    os.path.join(output_dir, 'G_opt_epo' + str(-1) + '.pth'))
                torch.save(
                    discriminator_optimizer.state_dict(),
                    os.path.join(output_dir, 'D_opt_epo' + str(-1) + '.pth'))
                torch.save(
                    generator_scheduler.state_dict(),
                    os.path.join(output_dir, 'G_sch_epo' + str(-1) + '.pth'))
                torch.save(
                    discriminator_scheduler.state_dict(),
                    os.path.join(output_dir, 'D_sch_epo' + str(-1) + '.pth'))
            else:
                logger.info("Rank 1 results: {}".format(', '.join(
                    [str(s) for s in R1s])))

    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    logger.info('-' * 10)
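# The helper below is a minimal, self-contained sketch (not called anywhere in
# this file) of the alternating adversarial update that train() performs per
# batch: D is stepped STEP times on features detached from G's graph, then G
# is stepped once with the negated domain loss plus the ReID objective. The
# argument names are illustrative; `dom_crit` and `reid_crit` stand in for the
# label-smoothed domain loss and the classification term used above (the
# triplet term on the feature vector is omitted for brevity), and `gan_vec` is
# assumed to be a single tensor rather than a list.
def _adversarial_update_sketch(images, labels, domains, G, D, g_opt, d_opt,
                               reid_crit, dom_crit, step=4, anneal=1.0,
                               d_w=1.0, g_w=1.0, c_w=1.0):
    _feats, scores, gan_vec = G(images)

    # 1) Discriminator: several steps on detached features, so no gradient
    #    reaches G during this phase.
    for p in G.parameters():
        p.requires_grad = False
    for p in D.parameters():
        p.requires_grad = True
    for _ in range(step):
        d_opt.zero_grad()
        d_loss = dom_crit(D(gan_vec.detach()), domains).mean()
        (anneal * d_w * d_loss).backward()
        d_opt.step()

    # 2) Generator: one step that tries to fool D (negated domain loss) while
    #    keeping the re-identification objective.
    for p in D.parameters():
        p.requires_grad = False
    for p in G.parameters():
        p.requires_grad = True
    g_opt.zero_grad()
    g_loss = -1. * dom_crit(D(gan_vec), domains).mean()
    reid_loss = reid_crit(scores, labels).mean()
    (anneal * g_w * g_loss + c_w * reid_loss).backward()
    g_opt.step()
    return d_loss.item(), g_loss.item(), reid_loss.item()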
def trace(config_file, model_type="generalizer", **kwargs):
    if model_type == "normal":
        from config.default_multi_domain import _C as cfg
    elif model_type == "generalizer":
        from config.default_multi_domain import _C as cfg
    else:
        raise ValueError("Model type can only be normal or generalizer.")

    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    # PersonReID_Dataset_Downloader('./datasets',cfg.DATASETS.NAMES)
    train_loader, _, _, num_classes = data_loader(cfg,
                                                  cfg.DATASETS.SOURCE,
                                                  merge=cfg.DATASETS.MERGE)

    device = torch.device(cfg.DEVICE)

    if model_type == "generalizer":
        module = getattr(generalizers, cfg.MODEL.NAME)
        model = getattr(module, 'Generalizer_G')(num_classes,
                                                 cfg.MODEL.LAST_STRIDE,
                                                 cfg.MODEL.POOL)
        checkpoints = get_last_stats(cfg.OUTPUT_DIR)
        model_dict = torch.load(checkpoints[str(type(model))])
        model.load_state_dict(model_dict)
    elif model_type == "normal":
        model = getattr(models, cfg.MODEL.NAME)(num_classes,
                                                cfg.MODEL.LAST_STRIDE,
                                                cfg.MODEL.POOL)
        checkpoints = get_last_stats(cfg.OUTPUT_DIR, [cfg.MODEL.NAME])
        model_dict = torch.load(checkpoints[cfg.MODEL.NAME])
        model.load_state_dict(model_dict)
    model = model.eval()

    input_names = ['input']
    output_names = ['output']
    batch = 1
    images = torch.randn(batch, 3, 256, 128, requires_grad=True)
    if device:
        model.to(device)
        images = images.to(device)

    torch.onnx.export(
        model,
        images,
        'test.onnx',
        verbose=True,
        input_names=input_names,
        output_names=output_names
    )  # , dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
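# Optional sanity check for the export above (a sketch, not wired into the
# CLI): it assumes the `onnxruntime` package is installed and that the model
# returns a single feature tensor in eval mode, as in the test loops below.
# `onnx_path` must match the file name passed to torch.onnx.export.
def _check_onnx_export(model, images, onnx_path='test.onnx', atol=1e-4):
    import onnxruntime as ort
    model = model.eval()
    with torch.no_grad():
        torch_out = model(images).detach().cpu().numpy()
    sess = ort.InferenceSession(onnx_path,
                                providers=['CPUExecutionProvider'])
    onnx_out = sess.run(None, {'input': images.detach().cpu().numpy()})[0]
    return np.allclose(torch_out, onnx_out, atol=atol)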
def test_random_datasets(config_file,
                         iteration=10,
                         model_type="generalizer",
                         **kwargs):
    if model_type == "normal":
        from config.default_multi_domain import _C as cfg
    elif model_type == "generalizer":
        from config.default_multi_domain import _C as cfg
    else:
        raise ValueError("Model type can only be normal or generalizer.")

    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    # PersonReID_Dataset_Downloader('./datasets',cfg.DATASETS.NAMES)
    _, _, _, num_classes = data_loader(cfg,
                                       cfg.DATASETS.SOURCE,
                                       merge=cfg.DATASETS.MERGE)
    re_ranking = cfg.RE_RANKING

    device = torch.device(cfg.DEVICE)

    if model_type == "generalizer":
        module = getattr(generalizers, cfg.MODEL.NAME)
        model = getattr(module, 'Generalizer_G')(num_classes,
                                                 cfg.MODEL.LAST_STRIDE,
                                                 cfg.MODEL.POOL)
        checkpoints = get_last_stats(cfg.OUTPUT_DIR)
        model_dict = torch.load(checkpoints[str(type(model))])
        model.load_state_dict(model_dict)
    elif model_type == "normal":
        model = getattr(models, cfg.MODEL.NAME)(num_classes,
                                                cfg.MODEL.LAST_STRIDE,
                                                cfg.MODEL.POOL)
        checkpoints = get_last_stats(cfg.OUTPUT_DIR, [cfg.MODEL.NAME])
        model_dict = torch.load(checkpoints[cfg.MODEL.NAME])
        model.load_state_dict(model_dict)
    model = model.eval()

    if not re_ranking:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR,
                             'epo' + str(checkpoints['epo']))
        logger.info("Test Results:")
    else:
        logger = make_logger("Reid_Baseline", cfg.OUTPUT_DIR,
                             'epo' + str(checkpoints['epo']) + '_re-ranking')
        logger.info("Re-Ranking Test Results:")

    for test_dataset in cfg.DATASETS.TARGET:
        mAPs = []
        cmcs = []
        for i in range(iteration):
            set_seeds(i)
            _, val_loader, num_query, _ = data_loader(cfg, (test_dataset, ),
                                                      merge=False)
            all_feats = []
            all_pids = []
            all_camids = []

            since = time.time()
            for data in tqdm(val_loader,
                             desc='Feature Extraction',
                             leave=False):
                model.eval()
                with torch.no_grad():
                    images, pids, camids = data
                    if device:
                        model.to(device)
                        images = images.to(device)
                    feats = model(images)
                    feats /= feats.norm(dim=-1, keepdim=True)
                    all_feats.append(feats)
                    all_pids.extend(np.asarray(pids))
                    all_camids.extend(np.asarray(camids))

            cmc, mAP = evaluation(all_feats, all_pids, all_camids, num_query,
                                  re_ranking)
            mAPs.append(mAP)
            cmcs.append(cmc)

        mAP = np.mean(np.array(mAPs))
        cmc = np.mean(np.array(cmcs), axis=0)
        mAP_std = np.std(np.array(mAPs))
        cmc_std = np.std(np.array(cmcs), axis=0)

        logger.info("mAP: {:.1%} (std: {:.3%})".format(mAP, mAP_std))
        for r in [1, 5, 10]:
            logger.info("CMC curve, Rank-{:<3}:{:.1%} (std: {:.3%})".format(
                r, cmc[r - 1], cmc_std[r - 1]))

        test_time = time.time() - since
        logger.info('Testing complete in {:.0f}m {:.0f}s'.format(
            test_time // 60, test_time % 60))
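# The L2 normalisation applied to the features above makes Euclidean ranking
# equivalent to cosine ranking, since ||q - g||^2 = 2 - 2 * q.g for unit
# vectors. The helper below only illustrates that relationship; the actual
# distance computation lives inside the `evaluation` utility.
def _pairwise_sqdist(query_feats, gallery_feats):
    # Both inputs are assumed L2-normalized, with shapes (Nq, D) and (Ng, D).
    sim = query_feats @ gallery_feats.t()
    return 2.0 - 2.0 * sim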
def visualize(config_file,
              model_type="generalizer",
              mode="img",
              num_batch=2,
              metric='euclidean',
              vis_options=['train'],
              perplexity=30,
              save_dir='./',
              **kwargs):
    """
    Visualize the feature vector space with options.

    CAUTION: t-SNE is sensitive to its parameters and to all potential
    clusters in the input; that is why the unique_by_pid function is used.
    Do not use exact mode for t-SNE.
    See https://distill.pub/2016/misread-tsne/ for details.

    Parameters
    ----------
    config_file : str
        path to a yacs config file (merged into cfg)
    model_type : str
        generalizer or normal models
    mode : str
        plot dots or images on the t-SNE visualization
    num_batch : int
        number of batches from the data used to visualize
    metric : str
        distance metric for t-SNE: cosine or euclidean
    vis_options : List[str]
        visualize on train and/or test data
    perplexity : int
        a critical parameter for t-SNE. Try multiple values to see the effects.

    Returns
    -------
    Saves a png picture to save_dir.
    """
    if mode not in ["img", "dot"]:
        raise ValueError("Mode can only be img or dot.")
    if model_type == "normal":
        from config.default_multi_domain import _C as cfg
    elif model_type == "generalizer":
        from config.default_multi_domain import _C as cfg
    else:
        raise ValueError("Model type can only be normal or generalizer.")

    cfg.merge_from_file(config_file)
    if kwargs:
        opts = []
        for k, v in kwargs.items():
            opts.append(k)
            opts.append(v)
        cfg.merge_from_list(opts)
    cfg.freeze()

    def recover_image(image):
        # Final format is still Tensor.
        image = image.permute(0, 2, 3, 1) * torch.Tensor(
            cfg.INPUT.PIXEL_STD) + torch.Tensor(cfg.INPUT.PIXEL_MEAN)
        return image

    # PersonReID_Dataset_Downloader('./datasets',cfg.DATASETS.NAMES)
    _, _, _, num_classes = data_loader(cfg,
                                       cfg.DATASETS.SOURCE,
                                       merge=cfg.DATASETS.MERGE,
                                       verbose=False)

    device = torch.device(cfg.DEVICE)

    if model_type == "normal":
        model = getattr(models, cfg.MODEL.NAME)(num_classes,
                                                cfg.MODEL.LAST_STRIDE,
                                                cfg.MODEL.POOL)
        checkpoints = get_last_stats(cfg.OUTPUT_DIR, [cfg.MODEL.NAME])
        model_dict = torch.load(checkpoints[cfg.MODEL.NAME])
        model.load_state_dict(model_dict)
    elif model_type == "generalizer":
        import generalizers
        module = getattr(generalizers, cfg.MODEL.NAME)
        G = getattr(module, 'Generalizer_G')(num_classes,
                                             cfg.MODEL.LAST_STRIDE,
                                             cfg.MODEL.POOL)
        checkpoints = get_last_stats(
            cfg.OUTPUT_DIR,
            [str(type(G)), 'D_opt', 'G_opt', 'D_sch', 'G_sch', 'epo'])
        G.load_state_dict(torch.load(checkpoints[str(type(G))]))
        if device:
            # must be done before the optimizer generation
            G.to(device)
        model = G
    model = model.eval()

    x = []
    y = []
    imgs = []
    NUM = num_batch

    if 'test' in vis_options:
        test_val_stats = [
            data_loader(cfg, (target, ), merge=False, verbose=False)[1]
            for target in cfg.DATASETS.TARGET
        ]
        for i, val_loader in enumerate(tqdm.tqdm(test_val_stats, desc="")):
            count = 0
            for data in val_loader:
                model.eval()
                with torch.no_grad():
                    images, pids, camids = data
                    pids, images, camids = unique_by_pid(pids, images, camids)
                    imgs.append(recover_image(images).data.numpy())
                    if device:
                        model.to(device)
                        images = images.to(device)
                    feats = model(images)
                    if metric == 'cosine':
                        feats /= feats.norm(dim=-1, keepdim=True)
                    x.append(feats.data.cpu().numpy())
                    y.extend([cfg.DATASETS.TARGET[i] + '_T' for e in pids])
                count += 1
                if count == NUM:
                    break

    if 'train' in vis_options:
        train_val_stats = [
            data_loader(cfg, (source, ), merge=False, verbose=False)[0]
            for source in cfg.DATASETS.SOURCE
        ]
        for i, train_loader in enumerate(train_val_stats):
            count = 0
            for data in train_loader:
                model.eval()
                with torch.no_grad():
                    images, pids, camids = data
                    pids, images, camids = unique_by_pid(pids, images, camids)
                    imgs.append(recover_image(images).data.numpy())
                    if device:
                        model.to(device)
                        images = images.to(device)
                    feats = model(images)
                    if metric == 'cosine':
                        feats /= feats.norm(dim=-1, keepdim=True)
                    x.append(feats.data.cpu().numpy())
                    y.extend([cfg.DATASETS.SOURCE[i] + '_S' for e in pids])
                count += 1
                if count == NUM:
                    break

    set_seeds(1)
    X = np.concatenate(x, 0)
    y = np.array(y)
    imgs = np.concatenate(imgs, 0)

    digits_proj = TSNE(random_state=RS, perplexity=perplexity,
                       metric=metric).fit_transform(X)

    if mode == "img":
        imscatter(digits_proj, y, imgs)
    elif mode == "dot":
        scatter(digits_proj, y)
    else:
        raise ValueError("Mode can only be img or dot.")

    plt.savefig(os.path.join(
        save_dir,
        'tsne-' + mode + '-' + metric + '_' + cfg.MODEL.NAME + '.png'),
                dpi=120)
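# t-SNE embeddings change noticeably with perplexity (see the docstring and
# the distill.pub article linked above). A simple sweep such as the sketch
# below writes one plot per setting; the values and output directories are
# illustrative only, not part of the CLI.
def _perplexity_sweep(config_file, values=(5, 15, 30, 50)):
    for p in values:
        out_dir = './tsne_p{}'.format(p)
        os.makedirs(out_dir, exist_ok=True)
        visualize(config_file, mode='dot', perplexity=p, save_dir=out_dir)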