def extract_video_features(self, data_loader, print_freq=1, metric=None):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    f1 = torch.zeros(len(data_loader), 256)
    f2 = torch.zeros(len(data_loader), 256)
    f3 = torch.zeros(len(data_loader), 256)
    Pids, Camids = [], []
    use_gpu = True
    end = time.time()
    for batch_idx, (imgs, pids, camids) in enumerate(data_loader):
        data_time.update(time.time() - end)
        if use_gpu:
            imgs = Variable(imgs.cuda(), volatile=True)
        b, s, c, h, w = imgs.size()
        imgs = imgs.view(b * s, c, h, w)
        feat1, feat2, feat3 = self.cnnmodel(imgs)
        pool = 'avg'
        feat1 = feat1.view(b, s, -1)
        feat2 = feat2.view(b, s, -1)
        feat3 = feat3.view(b, s, -1)
        if pool == 'avg':
            # Average the frame-level features over the temporal dimension.
            feat1 = torch.mean(feat1, 1)
            feat2 = torch.mean(feat2, 1)
            feat3 = torch.mean(feat3, 1)
        else:
            feat1, _ = torch.max(feat1, 1)
            feat2, _ = torch.max(feat2, 1)
            feat3, _ = torch.max(feat3, 1)
        feat1 = feat1.data.cpu()
        feat2 = feat2.data.cpu()
        feat3 = feat3.data.cpu()
        f1[batch_idx, :] = feat1
        f2[batch_idx, :] = feat2
        f3[batch_idx, :] = feat3
        Pids.extend(pids)
        Camids.extend(camids)
        batch_time.update(time.time() - end)
        end = time.time()
        if (batch_idx + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'
                  'num_frame {}\t'.format(batch_idx + 1, len(data_loader),
                                          batch_time.val, batch_time.avg,
                                          data_time.val, data_time.avg, s))
    Pids = np.asarray(Pids)
    Camids = np.asarray(Camids)
    print("Extracted features for data set, obtained {}-by-{} matrix".format(
        f1.size(0), f1.size(1)))
    return f1, f2, f3, Pids, Camids
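
# Hedged sketch (not part of the original pipeline): shows the shape contract
# of the temporal pooling above. A batch of b clips with s frames and d-dim
# frame features, shaped (b, s, d), reduces to (b, d) under either pooling.
def _demo_temporal_pooling():
    import torch
    feat = torch.randn(4, 8, 256)       # b=4 clips, s=8 frames, d=256
    avg_pooled = torch.mean(feat, 1)    # (4, 256)
    max_pooled, _ = torch.max(feat, 1)  # (4, 256)
    assert avg_pooled.shape == max_pooled.shape == (4, 256)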
def extract_n_save(model, data_loader, args, root, num_cams,
                   is_detection=True, use_fname=True, gt_type='reid'):
    model.eval()
    print_freq = 1000
    batch_time = AverageMeter()
    data_time = AverageMeter()
    if_created = [0 for _ in range(num_cams)]
    lines = [[] for _ in range(num_cams)]
    end = time.time()
    for i, (imgs, fnames, pids, cams) in enumerate(data_loader):
        data_time.update(time.time() - end)
        cams += 1
        outputs = extract_cnn_feature(model, imgs)
        for fname, output, pid, cam in zip(fnames, outputs, pids, cams):
            if is_detection:
                pattern = re.compile(r'c(\d+)_f(\d+)')
                cam, frame = map(int, pattern.search(fname).groups())
                # Detections carry no identity label; store pid 0.
                # f_names[cam - 1].append(fname)
                # features[cam - 1].append(output.numpy())
                line = np.concatenate(
                    [np.array([cam, 0, frame]), output.numpy()])
            else:
                if use_fname:
                    pattern = re.compile(r'(\d+)_c(\d+)_f(\d+)')
                    pid, cam, frame = map(int, pattern.search(fname).groups())
                else:
                    cam, pid = cam.numpy(), pid.numpy()
                    frame = -1 * np.ones_like(pid)
                # line = output.numpy()
                line = np.concatenate(
                    [np.array([cam, pid, frame]), output.numpy()])
            lines[cam - 1].append(line)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))
            # Flush the accumulated lines to disk periodically.
            if_created = save_file(lines, args, root, if_created)
            lines = [[] for _ in range(num_cams)]
    save_file(lines, args, root, if_created)
    return
def extract_features(model, data_loader, eval_only, print_freq=100):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    features = []
    labels = []
    cameras = []
    end = time.time()
    for i, (imgs, fnames, pids, cids) in enumerate(data_loader):
        data_time.update(time.time() - end)
        outputs = extract_cnn_feature(model, imgs, eval_only)
        for fname, output, pid, cid in zip(fnames, outputs, pids, cids):
            features.append(output)
            labels.append(int(pid.numpy()))
            cameras.append(int(cid.numpy()))
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))
    output_features = torch.stack(features, 0)
    return output_features, labels, cameras
def extract_features(model, data_loader, print_freq=50, metric=None):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    features = OrderedDict()
    labels = OrderedDict()
    end = time.time()
    for i, (imgs, fnames, pids, _) in enumerate(data_loader):
        data_time.update(time.time() - end)
        outputs = extract_cnn_feature(model, imgs)
        for fname, output, pid in zip(fnames, outputs, pids):
            features[fname] = output
            labels[fname] = pid
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))
    return features, labels
def extractfeature(self, data_loader):
    print_freq = 10
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    queryfeat1 = 0
    queryfeat2 = 0
    queryfeat3 = 0
    preimgs = 0
    for i, (imgs, fnames, pids, _) in enumerate(data_loader):
        data_time.update(time.time() - end)
        imgs = Variable(imgs, volatile=True)
        if i == 0:
            query_feat1, query_feat2, query_feat3 = self.cnnmodel(imgs)
            queryfeat1 = query_feat1
            queryfeat2 = query_feat2
            queryfeat3 = query_feat3
            preimgs = imgs
        elif imgs.size(0) < data_loader.batch_size:
            # Pad a short final batch with images from the first batch so the
            # network sees a full batch, then keep only the genuine features.
            flaw_batchsize = imgs.size(0)
            cat_batchsize = data_loader.batch_size - flaw_batchsize
            imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
            query_feat1, query_feat2, query_feat3 = self.cnnmodel(imgs)
            query_feat1 = query_feat1[0:flaw_batchsize]
            query_feat2 = query_feat2[0:flaw_batchsize]
            query_feat3 = query_feat3[0:flaw_batchsize]
            queryfeat1 = torch.cat((queryfeat1, query_feat1), 0)
            queryfeat2 = torch.cat((queryfeat2, query_feat2), 0)
            queryfeat3 = torch.cat((queryfeat3, query_feat3), 0)
        else:
            query_feat1, query_feat2, query_feat3 = self.cnnmodel(imgs)
            queryfeat1 = torch.cat((queryfeat1, query_feat1), 0)
            queryfeat2 = torch.cat((queryfeat2, query_feat2), 0)
            queryfeat3 = torch.cat((queryfeat3, query_feat3), 0)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))
    return queryfeat1, queryfeat2, queryfeat3
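
# Hedged sketch (illustrative only): the batch-padding trick used above. A
# short final batch is topped up with images from an earlier full batch so
# the network always sees a fixed batch size; only the features of the
# genuine images are kept afterwards.
def _demo_pad_last_batch(batch_size=8):
    import torch
    imgs = torch.randn(5, 3, 256, 128)            # short final batch
    preimgs = torch.randn(batch_size, 3, 256, 128)
    flaw = imgs.size(0)
    padded = torch.cat((imgs, preimgs[0:batch_size - flaw]), 0)
    assert padded.size(0) == batch_size
    feats = padded.view(batch_size, -1).mean(1)   # stand-in for the CNN
    return feats[0:flaw]                          # keep genuine features only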
def evaluate(self, queryloader, galleryloader, query, gallery):
    query_features = self.extractfeature(queryloader)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    print_freq = 50
    distmat = 0
    self.cnnmodel.eval()
    self.classifier.eval()
    for i, (imgs, _, pids, _) in enumerate(galleryloader):
        data_time.update(time.time() - end)
        imgs = Variable(imgs, volatile=True)
        if i == 0:
            gallery_feat = self.cnnmodel(imgs)
            preimgs = imgs
        elif imgs.size(0) < galleryloader.batch_size:
            flaw_batchsize = imgs.size(0)
            cat_batchsize = galleryloader.batch_size - flaw_batchsize
            imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
            gallery_feat = self.cnnmodel(imgs)
            gallery_feat = gallery_feat[0:flaw_batchsize]
        else:
            gallery_feat = self.cnnmodel(imgs)
        batch_cls_encode = self.classifier(query_features, gallery_feat)
        batch_cls_size = batch_cls_encode.size()
        batch_cls_encode = batch_cls_encode.view(-1, 2)
        batch_cls_encode = F.softmax(batch_cls_encode, 1)
        batch_cls_encode = batch_cls_encode.view(batch_cls_size[0],
                                                 batch_cls_size[1], 2)
        # Channel 0 of the two-way softmax is P(different), used as distance.
        batch_encode = batch_cls_encode[:, :, 0]
        if i == 0:
            distmat = batch_encode.data
        else:
            distmat = torch.cat((distmat, batch_encode.data), 1)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(
                      i + 1, len(galleryloader), batch_time.val,
                      batch_time.avg, data_time.val, data_time.avg))
    return evaluate_all(distmat, query=query, gallery=gallery)
def sim_computation(self, galleryloader, query_features):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    print_freq = 50
    simmat = 0
    for i, (imgs, _, pids, _) in enumerate(galleryloader):
        data_time.update(time.time() - end)
        imgs = Variable(imgs, volatile=True)
        if i == 0:
            gallery_feat = self.cnnmodel(imgs)
            preimgs = imgs
        elif imgs.size(0) < galleryloader.batch_size:
            flaw_batchsize = imgs.size(0)
            cat_batchsize = galleryloader.batch_size - flaw_batchsize
            imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
            gallery_feat = self.cnnmodel(imgs)
            gallery_feat = gallery_feat[0:flaw_batchsize]
        else:
            gallery_feat = self.cnnmodel(imgs)
        batch_cls_encode = self.classifier(query_features, gallery_feat)
        batch_cls_size = batch_cls_encode.size()
        batch_cls_encode = batch_cls_encode.view(-1, 2)
        batch_cls_encode = F.softmax(batch_cls_encode, 1)
        batch_cls_encode = batch_cls_encode.view(batch_cls_size[0],
                                                 batch_cls_size[1], 2)
        # Channel 1 of the two-way softmax is P(same), used as similarity.
        batch_similarity = batch_cls_encode[:, :, 1]
        if i == 0:
            simmat = batch_similarity
        else:
            simmat = torch.cat((simmat, batch_similarity), 1)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(
                      i + 1, len(galleryloader), batch_time.val,
                      batch_time.avg, data_time.val, data_time.avg))
    return simmat
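
# Hedged sketch: the verifier used above emits (Q, G, 2) logits; a softmax
# over the last axis yields P(different) in channel 0 and P(same) in channel
# 1, so channel 0 serves as a distance (see evaluate/compute_distmat) and
# channel 1 as a similarity (see sim_computation).
def _demo_verification_softmax():
    import torch
    import torch.nn.functional as F
    logits = torch.randn(3, 5, 2)    # Q=3 queries, G=5 gallery samples
    probs = F.softmax(logits.view(-1, 2), dim=1).view(3, 5, 2)
    distmat, simmat = probs[:, :, 0], probs[:, :, 1]
    assert torch.allclose(distmat + simmat, torch.ones(3, 5))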
def extract_features(model, data_loader, print_freq=1,
                     save_name='feature.mat'):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    ids = []
    cams = []
    features = []
    query_files = []
    end = time.time()
    for i, (imgs, fnames) in enumerate(data_loader):
        data_time.update(time.time() - end)
        outputs = extract_cnn_feature(model, imgs)
        # For test-time augmentation:
        # bs, ncrops, c, h, w = imgs.size()
        # outputs = extract_cnn_feature(model, imgs.view(-1, c, h, w))
        # outputs = outputs.view(bs, ncrops, -1).mean(1)
        for fname, output in zip(fnames, outputs):
            if fname[0] == '-':
                # Distractor/junk images are named with a leading '-1'.
                ids.append(-1)
                cams.append(int(fname[4]))
            else:
                ids.append(int(fname[:4]))
                cams.append(int(fname[6]))
            features.append(output.numpy())
            query_files.append(fname)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))
    return features, ids, cams, query_files
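
# Hedged sketch: illustrates the Market-1501-style naming the parsing above
# assumes, e.g. '0002_c1s1_000451_03.jpg' -> pid 2, cam 1, with distractor
# files prefixed '-1'. The sample names below are illustrative, not taken
# from the original loader.
def _demo_parse_fname(fname):
    if fname[0] == '-':
        return -1, int(fname[4])         # distractor: pid -1, cam digit
    return int(fname[:4]), int(fname[6])

# _demo_parse_fname('0002_c1s1_000451_03.jpg') -> (2, 1)
# _demo_parse_fname('-1_c3s2_000000_00.jpg')   -> (-1, 3)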
def extract_features(model, data_loader, print_freq=1):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    features = OrderedDict()
    labels = OrderedDict()
    print('extract feature')
    end = time.time()
    for i, data in enumerate(data_loader):
        imgs, npys, fnames, pids = (data.get('img'), data.get('npy'),
                                    data.get('fname'), data.get('pid'))
        data_time.update(time.time() - end)
        outputs = extract_cnn_feature(model, [imgs, npys])
        for fname, output, pid in zip(fnames, outputs, pids):
            features[fname] = output
            labels[fname] = pid
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg), imgs.shape)
    print(f'{len(features)} features, each of len '
          f'{next(iter(features.values())).shape[0]}')
    return features, labels
def extract_embeddings(model, data_loader, print_freq=10):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    embeddings = []
    print('extract embedding')
    end = time.time()
    for i, inputs in enumerate(data_loader):
        data_time.update(time.time() - end)
        outputs = extract_cnn_embeddings(model, inputs)
        # print(outputs.shape)
        embeddings.append(outputs)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Embedding: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(i + 1, len(data_loader),
                                                  batch_time.val,
                                                  batch_time.avg,
                                                  data_time.val,
                                                  data_time.avg))
    res = torch.cat(embeddings)
    print(res.shape)
    return res
def inference(model, query_loader, gallery_loader, use_gpu):
    batch_time = AverageMeter()
    model.eval()
    with torch.no_grad():
        qf = []
        for batch_idx, (imgs, _) in enumerate(query_loader):
            if use_gpu:
                imgs = imgs.cuda()
            end = time.time()
            features = extract_cnn_feature(model, imgs)
            batch_time.update(time.time() - end)
            features = features.data.cpu()
            qf.extend(list(features))
        gf, g_paths = [], []
        for batch_idx, (imgs, path) in enumerate(gallery_loader):
            if use_gpu:
                imgs = imgs.cuda()
            end = time.time()
            features = extract_cnn_feature(model, imgs)
            batch_time.update(time.time() - end)
            features = features.data.cpu()
            gf.extend(list(features))
            g_paths.extend(list(path))
        print('=> BatchTime(s): {:.3f}'.format(batch_time.avg))
    x = torch.cat([qf[i].unsqueeze(0) for i in range(len(qf))], 0)
    y = torch.cat([gf[i].unsqueeze(0) for i in range(len(gf))], 0)
    m, n = x.size(0), y.size(0)
    x = x.view(m, -1)
    y = y.view(n, -1)
    # Squared Euclidean distance via ||x||^2 + ||y||^2 - 2 * x . y.
    dist = torch.pow(x, 2).sum(1).unsqueeze(1).expand(m, n) + \
        torch.pow(y, 2).sum(1).unsqueeze(1).expand(n, m).t()
    dist.addmm_(1, -2, x, y.t())
    return dist
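
# Hedged sketch: checks the expansion used in inference above,
# ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * x . y, against torch.cdist.
def _demo_pairwise_sq_dist():
    import torch
    x, y = torch.randn(5, 16), torch.randn(7, 16)
    m, n = x.size(0), y.size(0)
    dist = torch.pow(x, 2).sum(1).unsqueeze(1).expand(m, n) + \
        torch.pow(y, 2).sum(1).unsqueeze(1).expand(n, m).t()
    dist = dist - 2 * x.mm(y.t())
    assert torch.allclose(dist, torch.cdist(x, y).pow(2), atol=1e-4)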
def trainMeta(meta_train_loader, meta_test_loader, net, noise, epoch,
              optimizer, centroids, metaCentroids, normalize):
    global args
    noise.requires_grad = True
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda()
    std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda()
    net.eval()
    end = time.time()
    optimizer.zero_grad()
    optimizer.rescale()
    for i, ((input, _, pid, _), (metaTest, _, _, _)) in enumerate(
            zip(meta_train_loader, meta_test_loader)):
        # measure data loading time.
        data_time.update(time.time() - end)
        net.zero_grad()
        input = input.cuda()
        metaTest = metaTest.cuda()
        # one step update
        with torch.no_grad():
            normInput = (input - mean) / std
            feature, realPred = net(normInput)
            scores = centroids.mm(F.normalize(feature.t(), p=2, dim=0))
            # scores = centroids.mm(feature.t())
            realLab = scores.max(0, keepdim=True)[1]
            _, ranks = torch.sort(scores, dim=0, descending=True)
            pos_i = ranks[0, :]
            neg_i = ranks[-1, :]
            neg_feature = centroids[neg_i, :]  # centroids--512*2048
            pos_feature = centroids[pos_i, :]
        current_noise = noise
        current_noise = F.interpolate(
            current_noise.unsqueeze(0),
            mode=MODE,
            size=tuple(input.shape[-2:]),
            align_corners=True,
        ).squeeze()
        perturbed_input = torch.clamp(input + current_noise, 0, 1)
        perturbed_input_norm = (perturbed_input - mean) / std
        perturbed_feature = net(perturbed_input_norm)[0]
        optimizer.zero_grad()
        pair_loss = 10 * F.triplet_margin_loss(perturbed_feature,
                                               neg_feature, pos_feature, 0.5)
        # clsScore = centroids.mm(perturbed_feature.t()).t()
        # oneHotReal = torch.zeros(clsScore.shape).cuda()
        # oneHotReal.scatter_(1, predLab.view(-1, 1), float(1))
        # oneHotReal = F.normalize(1 - oneHotReal, p=1, dim=1)
        # label_loss = -(F.log_softmax(clsScore, 1) * oneHotReal).sum(1).mean()
        fakePred = centroids.mm(perturbed_feature.t()).t()
        oneHotReal = torch.zeros(scores.t().shape).cuda()
        oneHotReal.scatter_(1, realLab.view(-1, 1), float(1))
        label_loss = F.relu(
            (fakePred * oneHotReal).sum(1).mean()
            - (fakePred * (1 - oneHotReal)).max(1)[0].mean())
        pair_loss = pair_loss.view(1)
        loss = pair_loss + label_loss
        # maml one step
        grad = torch.autograd.grad(loss, noise, create_graph=True)[0]
        noiseOneStep = keepGradUpdate(noise, optimizer, grad, MAX_EPS)
        # maml test
        newNoise = F.interpolate(
            noiseOneStep.unsqueeze(0),
            mode=MODE,
            size=tuple(metaTest.shape[-2:]),
            align_corners=True,
        ).squeeze()
        with torch.no_grad():
            normMte = (metaTest - mean) / std
            mteFeat = net(normMte)[0]
            scores = metaCentroids.mm(F.normalize(mteFeat.t(), p=2, dim=0))
            # scores = metaCentroids.mm(mteFeat.detach().t())
            metaLab = scores.max(0, keepdim=True)[1]
            _, ranks = torch.sort(scores, dim=0, descending=True)
            pos_i = ranks[0, :]
            neg_i = ranks[-1, :]
            neg_mte_feat = metaCentroids[neg_i, :]  # centroids--512*2048
            pos_mte_feat = metaCentroids[pos_i, :]
        perMteInput = torch.clamp(metaTest + newNoise, 0, 1)
        normPerMteInput = (perMteInput - mean) / std
        normMteFeat = net(normPerMteInput)[0]
        lossMeta = 10 * F.triplet_margin_loss(
            normMteFeat, neg_mte_feat, pos_mte_feat, 0.5)
        fakePredMeta = metaCentroids.mm(normMteFeat.t()).t()
        oneHotRealMeta = torch.zeros(scores.t().shape).cuda()
        oneHotRealMeta.scatter_(1, metaLab.view(-1, 1), float(1))
        labelLossMeta = F.relu(
            (fakePredMeta * oneHotRealMeta).sum(1).mean()
            - (fakePredMeta * (1 - oneHotRealMeta)).max(1)[0].mean())
        finalLoss = lossMeta + labelLossMeta + pair_loss + label_loss
        finalLoss.backward()
        losses.update(pair_loss.item())
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print(">> Train: [{0}][{1}/{2}]\t"
                  "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                  "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                  "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
                  "LossMeta {lossMeta:.4f}\t"
                  "Noise l2: {noise:.4f}".format(
                      epoch + 1, i, len(meta_train_loader),
                      batch_time=batch_time, data_time=data_time,
                      loss=losses, lossMeta=lossMeta.item(),
                      noise=noise.norm()))
    noise.requires_grad = False
    print(f"Train {epoch}: Loss: {losses.avg}")
    return losses.avg, noise
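
# Hedged sketch of a differentiable inner step, standing in for the repo's
# keepGradUpdate (whose actual behavior is not shown here): a plain gradient
# step on the noise, clamped to an L-inf budget, keeping the graph so the
# meta-test loss can backpropagate through the step.
def _demo_one_step_noise_update(noise, grad, lr=0.1, max_eps=8.0 / 255):
    import torch
    updated = noise - lr * grad               # inner gradient step
    return torch.clamp(updated, -max_eps, max_eps)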
def run():
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)
    cudnn.benchmark = True
    data_dir = opt.data_dir

    # Redirect print to both console and log file
    # if not opt.evaluate:
    #     sys.stdout = Logger(osp.join(opt.logs_dir, 'log_l2_per.txt'))

    # Create data loaders
    def readlist(path):
        lines = []
        with open(path, 'r') as f:
            data = f.readlines()
        for line in data:
            name, pid, cam = line.split()
            lines.append((name, int(pid), int(cam)))
        return lines

    # Load data lists for wuzhen
    if osp.exists(osp.join(data_dir, 'train.txt')):
        train_list = readlist(osp.join(data_dir, 'train.txt'))
    else:
        print("The training list doesn't exist")
    if osp.exists(osp.join(data_dir, 'val.txt')):
        val_list = readlist(osp.join(data_dir, 'val.txt'))
    else:
        print("The validation list doesn't exist")
    if osp.exists(osp.join(data_dir, 'query.txt')):
        query_list = readlist(osp.join(data_dir, 'query.txt'))
    else:
        print("The query.txt doesn't exist")
    if osp.exists(osp.join(data_dir, 'gallery.txt')):
        gallery_list = readlist(osp.join(data_dir, 'gallery.txt'))
    else:
        print("The gallery.txt doesn't exist")

    if opt.height is None or opt.width is None:
        opt.height, opt.width = (144, 56) if opt.arch == 'inception' else \
            (256, 128)
    train_loader, val_loader, test_loader = \
        get_data(opt.split, data_dir, opt.height, opt.width, opt.batchSize,
                 opt.workers, opt.combine_trainval, train_list, val_list,
                 query_list, gallery_list)

    # Create model
    # ori 14514; clear 12654, 16645
    densenet = densenet121(num_classes=20330, num_features=256)

    start_epoch = best_top1 = 0
    if opt.resume:
        # checkpoint = load_checkpoint(opt.resume)
        # densenet.load_state_dict(checkpoint['state_dict'])
        densenet.load_state_dict(torch.load(opt.resume))
        start_epoch = opt.resume_epoch
        print("=> Finetune Start epoch {}".format(start_epoch))
    if opt.pretrained_model:
        print('Start load params...')
        load_params(densenet, opt.pretrained_model)

    # Load from checkpoint
    # densenet = nn.DataParallel(densenet).cuda()
    metric = DistanceMetric(algorithm=opt.dist_metric)
    print('densenet')
    show_info(densenet, with_arch=True, with_grad=False)

    netG = netg()
    print('netG')
    show_info(netG, with_arch=True, with_grad=False)
    netG.apply(weights_init)
    if opt.netG != '':
        netG.load_state_dict(torch.load(opt.netG))
        # load_params(netG, opt.netG)
    if opt.cuda:
        netG = netG.cuda()
        densenet = densenet.cuda()

    perceptionloss = perception_loss(cuda=opt.cuda)
    l2loss = l2_loss(cuda=opt.cuda)
    # discriloss = discri_loss(cuda=opt.cuda, batchsize=opt.batchSize,
    #                          height=opt.height, width=opt.width, lr=opt.lr,
    #                          step_size=opt.step_size,
    #                          decay_step=opt.decay_step)

    # Evaluator
    evaluator = Evaluator(densenet)
    # if opt.evaluate:
    metric.train(densenet, train_loader)
    print("Validation:")
    evaluator.evaluate(val_loader, val_list, val_list, metric)
    print("Test:")
    evaluator.evaluate(test_loader, query_list, gallery_list, metric)
    # return

    # Criterion
    # criterion = nn.CrossEntropyLoss(ignore_index=-100).cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    # Optimizer
    param_groups = []
    mult_lr(densenet, param_groups)
    optimizer = optim.SGD(param_groups, lr=opt.lr, momentum=opt.momentum,
                          weight_decay=opt.weight_decay)
    # optimizer = optim.Adam(param_groups, lr=opt.lr, betas=(opt.beta1, 0.9))
    optimizerG = optim.Adam(netG.parameters(), lr=opt.lr,
                            betas=(opt.beta1, 0.9))

    # Start training
    for epoch in range(start_epoch, opt.epochs):
        adjust_lr(optimizer, epoch)
        adjust_lr(optimizerG, epoch)
        # discriloss.adjust_lr(epoch)
        losses = AverageMeter()
        precisions = AverageMeter()
        densenet.train()
        for i, data in enumerate(train_loader):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # train with real
            real_cpu, _, pids, _ = data
            if opt.cuda:
                real_cpu = real_cpu.cuda()
            targets = Variable(pids.cuda())
            # Wrap the batch directly for the forward pass.
            inputv = Variable(real_cpu)
            outputs, output_dense, _ = densenet(inputv)
            fake = netG(output_dense)
            fake = fake * 3
            # discriloss(fake=fake, inputv=inputv, i=i)

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ############################
            if i % opt.CRITIC_ITERS == 0:
                netG.zero_grad()
                optimizer.zero_grad()
                # loss_discri = discriloss.gloss(fake=fake)
                loss_l2 = l2loss(fake=fake, inputv=inputv)
                loss_perception = perceptionloss(fake=fake, inputv=inputv)
                loss_classify = criterion(outputs, targets)
                prec, = accuracy(outputs.data, targets.data)
                prec = prec[0]
                losses.update(loss_classify.data[0], targets.size(0))
                precisions.update(prec, targets.size(0))
                loss = loss_classify + 0 * loss_l2 + 0 * loss_perception
                # loss = loss_discri
                loss.backward()
                optimizerG.step()
                optimizer.step()
                print('[%d/%d][%d/%d] Loss_l2: %.4f Loss_perception: %.4f' %
                      (epoch, opt.epochs, i, len(train_loader),
                       loss_l2.data[0], loss_perception.data[0]))
                print('Loss {}({})\t'
                      'Prec {}({})\t'.format(losses.val, losses.avg,
                                             precisions.val, precisions.avg))
            if i % 100 == 0:
                vutils.save_image(real_cpu, '%s/real_samples.png' % opt.outf,
                                  normalize=True)
                outputs, output_dense, _ = densenet(x=inputv)
                fake = netG(output_dense)
                fake = fake * 3
                vutils.save_image(fake.data,
                                  '%s/fake_samples_epoch_%03d.png'
                                  % (opt.outf, epoch), normalize=True)
                show_info(densenet, with_arch=False, with_grad=True)
                show_info(netG, with_arch=False, with_grad=True)

        if epoch % 5 == 0:
            torch.save(densenet.state_dict(),
                       '%s/densenet_epoch_%d.pth' % (opt.outf, epoch))
            torch.save(netG.state_dict(),
                       '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
        if epoch < opt.start_save:
            continue
        top1 = evaluator.evaluate(val_loader, val_list, val_list)
        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        save_checkpoint(
            {
                'state_dict': densenet.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            },
            is_best,
            fpath=osp.join(opt.logs_dir, 'checkpoint.pth.tar'))
        print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'.
              format(epoch, top1, best_top1, ' *' if is_best else ''))
        if (epoch + 1) % 5 == 0:
            print('Test model: \n')
            evaluator.evaluate(test_loader, query_list, gallery_list)
            model_name = 'epoch_' + str(epoch) + '.pth.tar'
            torch.save({'state_dict': densenet.state_dict()},
                       osp.join(opt.logs_dir, model_name))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(opt.logs_dir,
                                          'model_best.pth.tar'))
    densenet.load_state_dict(checkpoint['state_dict'])
    print('best epoch: ', checkpoint['epoch'])
    metric.train(densenet, train_loader)
    evaluator.evaluate(test_loader, query_list, gallery_list, metric)
def trainMeta(meta_train_loader, meta_test_loader, net, epoch, normalize,
              perturbation, optimizer):
    global args
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda()
    std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda()
    net.eval()
    end = time.time()
    perturbation.zero_grad()
    optimizer.zero_grad()
    for i, (input, _, pids, _) in enumerate(meta_train_loader):
        metaTest, _, mtepids, _ = next(meta_test_loader)
        data_time.update(time.time() - end)
        net.zero_grad()
        input = input.cuda()
        metaTest = metaTest.cuda()
        # one step update
        with torch.no_grad():
            norm_output = (input - mean) / std
            feature = net(norm_output)[0]
        current_noise = perturbation
        perturbed_input = current_noise(input)
        perturbed_input_clamp = torch.clamp(perturbed_input, 0, 1)
        perturbed_input_norm = (perturbed_input_clamp - mean) / std
        perturbed_feature = net(perturbed_input_norm)[0]
        optimizer.zero_grad()
        loss = TripletLoss()(feature, pids.cuda(), perturbed_feature)
        # maml one step
        noise = perturbation.parameters()
        grad = torch.autograd.grad(loss, noise, create_graph=True)
        perturbation_new = keepGradUpdate(perturbation, optimizer, grad)
        # maml test
        with torch.no_grad():
            normMte = (metaTest - mean) / std
            mteFeat = net(normMte)[0]
        perMteInput = perturbation_new(metaTest)
        perMteInput = torch.clamp(perMteInput, 0, 1)
        normPerMteInput = (perMteInput - mean) / std
        normMteFeat = net(normPerMteInput)[0]
        mteloss = TripletLoss()(mteFeat, mtepids.cuda(), normMteFeat)
        finalLoss = loss + mteloss
        finalLoss.backward()
        losses.update(loss.item())
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print(">> Train: [{0}][{1}/{2}]\t"
                  "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                  "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                  "Loss {loss.val:.4f} ({loss.avg:.4f})".format(
                      epoch + 1, i, len(meta_train_loader),
                      batch_time=batch_time, data_time=data_time,
                      loss=losses))
    print(f"Train {epoch}: Loss: {losses.avg}")
    # Freeze the learned perturbation before returning it.
    perturbation.requires_grad_(False)
    return losses.avg, perturbation
def single_train(self, model, criterion, optimizer, trial):
    model.train()
    criterion.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()
    iters = 0
    start_time = time.time()
    end = time.time()
    for ep in range(self.num_epochs):
        for i, inputs in enumerate(self.data_loader):
            data_time.update(time.time() - end)
            iters += 1
            inputs, targets = self._parse_data(inputs)
            loss, acc = self._forward(model, criterion, inputs, targets)
            losses.update(loss.item(), targets.size(0))
            precisions.update(acc, targets.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            print('Trial {}: epoch [{}][{}/{}]. '
                  'Time: {:.3f} ({:.3f}). '
                  'Data: {:.3f} ({:.3f}). '
                  'Metric: {:.4f} ({:.4f}). '
                  'Loss: {:.3f} ({:.3f}). '
                  'Prec: {:.2%} ({:.2%}).'.format(
                      trial + 1, ep, i + 1,
                      min(self.max_steps, len(self.data_loader)),
                      batch_time.val, batch_time.avg,
                      data_time.val, data_time.avg,
                      precisions.val / losses.val,
                      precisions.avg / losses.avg,
                      losses.val, losses.avg,
                      precisions.val, precisions.avg),
                  end='\r', file=sys.stdout.console)
            if iters == self.max_steps - 1:
                break
        if iters == self.max_steps - 1:
            break
    loss = losses.avg
    acc = precisions.avg
    print('* Trial %d. Metric: %.4f. Loss: %.3f. Acc: %.2f%%. '
          'Training time: %.0f seconds. \n' %
          (trial + 1, acc / loss, loss, acc * 100,
           time.time() - start_time))
    return loss, acc
def train(args, model, train_loader, start_epoch):
    """Train classifier for source domain."""
    ####################
    # 1. setup network #
    ####################
    base_param_ids = set(map(id, model.module.base.parameters()))
    new_params = [p for p in model.parameters()
                  if id(p) not in base_param_ids]
    param_groups = [{
        'params': model.module.base.parameters(),
        'lr_mult': 0.1
    }, {
        'params': new_params,
        'lr_mult': 1.0
    }]
    optimizer = optim.Adam(param_groups, lr=args.lr)

    # Criterion
    criterion = CapsuleLoss()
    criterion2 = nn.CrossEntropyLoss().cuda()
    criterion3 = nn.CrossEntropyLoss().cuda()

    # Schedule learning rate
    def adjust_lr(epoch):
        lr = args.lr * (0.1 ** (epoch // args.step_size))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    ####################
    # 2. train network #
    ####################
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        print_freq = 1
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions_id = AverageMeter()
        precisions_id2 = AverageMeter()
        precisions_id3 = AverageMeter()
        end = time.time()
        for i, inputs in enumerate(train_loader):
            data_time.update(time.time() - end)
            imgs, pids, imgname = inputs
            inputs = Variable(imgs.cuda())
            # One-hot targets for the capsule loss.
            labels = torch.eye(args.true_class).index_select(dim=0,
                                                             index=pids)
            labels = Variable(labels.cuda())
            targets = Variable(pids.cuda())
            results, y, y2 = model(inputs)
            loss1 = criterion(imgs, labels, results)
            loss2 = criterion2(y, targets)
            loss3 = criterion3(y2, targets)
            prec, = accuracy_capsule(results.data, targets.data,
                                     args.true_class)
            prec = prec[0]
            prec2, = accuracy(y.data, targets.data)
            prec2 = prec2[0]
            prec3, = accuracy(y2.data, targets.data)
            prec3 = prec3[0]
            loss = loss1 + 0.5 * loss2 + 0.5 * loss3
            # update the re-id model
            losses.update(loss.data.item(), targets.size(0))
            precisions_id.update(prec, targets.size(0))
            precisions_id2.update(prec2, targets.size(0))
            precisions_id3.update(prec3, targets.size(0))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            batch_time.update(time.time() - end)
            end = time.time()
            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec_capsule {:.2%} ({:.2%})\t'
                      'Prec_ID2 {:.2%} ({:.2%})\t'
                      'Prec_ID3 {:.2%} ({:.2%})\t'.format(
                          epoch, i + 1, len(train_loader),
                          batch_time.val, batch_time.avg,
                          data_time.val, data_time.avg,
                          losses.val, losses.avg,
                          precisions_id.val, precisions_id.avg,
                          precisions_id2.val, precisions_id2.avg,
                          precisions_id3.val, precisions_id3.avg))
        # save model
        if (epoch + 1) % 5 == 0:
            save_checkpoint(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch + 1,
                },
                fpath=osp.join(args.logs_dir,
                               'checkpoint' + str(epoch + 1) + '.pth.tar'))
        print('\n * Finished epoch {:3d} \n'.format(epoch))
    return model
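
# Hedged sketch: the staircase schedule in adjust_lr above decays the base
# rate tenfold every args.step_size epochs, scaled per group by lr_mult.
def _demo_staircase_lr(base_lr=0.1, step_size=40):
    # epochs 0..39 -> 0.1, 40..79 -> 0.01, 80..119 -> 0.001, ...
    return [base_lr * (0.1 ** (epoch // step_size))
            for epoch in (0, 39, 40, 80)]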
def compute_distmat(self, queryloader, galleryloader):
    self.cnnmodel.eval()
    self.classifier.eval()
    queryfeat1, queryfeat2, queryfeat3 = self.extractfeature(queryloader)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    print_freq = 50
    distmat = 0
    for i, (imgs, _, pids, _) in enumerate(galleryloader):
        data_time.update(time.time() - end)
        imgs = Variable(imgs, volatile=True)
        if i == 0:
            gallery_feat1, gallery_feat2, gallery_feat3 = self.cnnmodel(imgs)
            preimgs = imgs
        elif imgs.size(0) < galleryloader.batch_size:
            flaw_batchsize = imgs.size(0)
            cat_batchsize = galleryloader.batch_size - flaw_batchsize
            imgs = torch.cat((imgs, preimgs[0:cat_batchsize]), 0)
            gallery_feat1, gallery_feat2, gallery_feat3 = self.cnnmodel(imgs)
            gallery_feat1 = gallery_feat1[0:flaw_batchsize]
            gallery_feat2 = gallery_feat2[0:flaw_batchsize]
            gallery_feat3 = gallery_feat3[0:flaw_batchsize]
        else:
            gallery_feat1, gallery_feat2, gallery_feat3 = self.cnnmodel(imgs)
        batch_cls_encode1, batch_cls_encode2, batch_cls_encode3 = \
            self.classifier(queryfeat1, gallery_feat1, queryfeat2,
                            gallery_feat2, queryfeat3, gallery_feat3)
        batch_cls_size1 = batch_cls_encode1.size()
        batch_cls_encode1 = batch_cls_encode1.view(-1, 2)
        batch_cls_encode1 = F.softmax(batch_cls_encode1, 1)
        batch_cls_encode1 = batch_cls_encode1.view(batch_cls_size1[0],
                                                   batch_cls_size1[1], 2)
        batch_cls_encode1 = batch_cls_encode1[:, :, 0]
        batch_cls_size2 = batch_cls_encode2.size()
        batch_cls_encode2 = batch_cls_encode2.view(-1, 2)
        batch_cls_encode2 = F.softmax(batch_cls_encode2, 1)
        batch_cls_encode2 = batch_cls_encode2.view(batch_cls_size2[0],
                                                   batch_cls_size2[1], 2)
        batch_cls_encode2 = batch_cls_encode2[:, :, 0]
        batch_cls_size3 = batch_cls_encode3.size()
        batch_cls_encode3 = batch_cls_encode3.view(-1, 2)
        batch_cls_encode3 = F.softmax(batch_cls_encode3, 1)
        batch_cls_encode3 = batch_cls_encode3.view(batch_cls_size3[0],
                                                   batch_cls_size3[1], 2)
        batch_cls_encode3 = batch_cls_encode3[:, :, 0]
        # Fuse the three branch distances with the alpha weights.
        batch_cls_encode = (batch_cls_encode1 * self.alphas[0]
                            + batch_cls_encode2 * self.alphas[1]
                            + batch_cls_encode3 * self.alphas[2])
        if i == 0:
            distmat = batch_cls_encode.data
        else:
            distmat = torch.cat((distmat, batch_cls_encode.data), 1)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(
                      i + 1, len(galleryloader), batch_time.val,
                      batch_time.avg, data_time.val, data_time.avg))
    return distmat
if len(gpus) < 2:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
# If multi gpus
if len(gpus) > 1:
    model = torch.nn.DataParallel(model, range(len(args.gpus))).cuda()
if args.pretrained_weights_dir:
    model = torch.load(args.pretrained_weights_dir)
else:
    model = torch.load(os.path.join(exp_dir, 'model.pth'))
model.eval()
batch_time = AverageMeter()
data_time = AverageMeter()
features = OrderedDict()
labels = OrderedDict()
end = time.time()
print('Extracting features... This may take a while...')
with torch.no_grad():
    for i, (imgs, fnames, pids, _) in enumerate(test_loader):
        data_time.update(time.time() - end)
        # Horizontal-flip test-time augmentation.
        imgs_flip = torch.flip(imgs, [3])
        final_feat_list, _, _, _, _ = model(Variable(imgs).cuda())
        final_feat_list_flip, _, _, _, _ = model(Variable(imgs_flip).cuda())
def main(args):
    args = parser.parse_args()
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'Part_log.txt'))

    # Create data loaders
    assert args.num_instances > 1, "num_instances should be greater than 1"
    assert args.batch_size % args.num_instances == 0, \
        'num_instances should divide batch_size'
    if args.height is None or args.width is None:
        args.height, args.width = (144, 56) if args.arch == 'inception' else \
            (256, 128)
    dataset, num_classes, train_loader, val_loader, test_loader = \
        get_data(args.dataset, args.split, args.data_dir, args.height,
                 args.width, args.batch_size, args.num_instances,
                 args.workers, args.combine_trainval)

    # Create model
    # Hacking here to let the classifier be the last feature embedding layer
    # Net structure: avgpool -> FC(1024) -> FC(args.features)
    model = models.create(args.arch, num_features=512, pretrained=True,
                          dropout=args.dropout, num_classes=args.features,
                          embedding=False)

    # Load from checkpoint
    start_epoch = best_top1 = 0
    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        # start_epoch = checkpoint['epoch']
        start_epoch = 0
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {} best top1 {:.1%}".format(
            start_epoch, best_top1))
    model = nn.DataParallel(model)
    if args.cuda:
        model.cuda()

    # Distance metric
    metric = DistanceMetric(algorithm=args.dist_metric)

    # Evaluator
    evaluator = Evaluator(model)
    if args.evaluate:
        metric.train(model, train_loader)
        print("Validation:")
        evaluator.evaluate(val_loader, dataset.val, dataset.val, metric)
        print("Test:")
        evaluator.evaluate(test_loader, dataset.query, dataset.gallery,
                           metric)
        return

    # Criterion
    criterion = TripletLoss(margin=args.margin)
    if args.cuda:
        criterion.cuda()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)

    # Trainer
    trainer = Trainer(model, criterion)

    # Schedule learning rate
    def adjust_lr(epoch):
        lr = args.lr if epoch <= 100 else \
            args.lr * (0.001 ** ((epoch - 100) / 50.0))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    accs_market = AverageMeter()
    accs_cuhk03 = AverageMeter()
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer)
        if epoch < args.start_save:
            continue
        top1, cuhk03_top1, market_top1 = evaluator.evaluate(
            val_loader, dataset.val, dataset.val)
        accs_market.update(market_top1, args.batch_size * 40)
        accs_cuhk03.update(cuhk03_top1, args.batch_size * 40)
        plotter.plot('acc', 'test-multishot', epoch, market_top1)
        plotter.plot('acc', 'test-singleshot', epoch, cuhk03_top1)
        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        save_checkpoint(
            {
                'state_dict': model.module.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1,
            },
            is_best,
            fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
        print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'.
              format(epoch, top1, best_top1, ' *' if is_best else ''))

    # Final test
    print('Test with best model:')
    checkpoint = load_checkpoint(osp.join(args.logs_dir,
                                          'model_best.pth.tar'))
    model.module.load_state_dict(checkpoint['state_dict'])
    metric.train(model, train_loader)
    evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
def train(self, epoch, data_loader, optimizer):
    self.model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()
    precisions1 = AverageMeter()
    precisions2 = AverageMeter()
    end = time.time()
    for i, inputs in enumerate(data_loader):
        data_time.update(time.time() - end)
        inputs, targets = self._parse_data(inputs)
        loss, prec_oim, loss_score, prec_finalscore = self._forward(
            inputs, targets, i)
        losses.update(loss.data.item(), targets.size(0))
        precisions.update(prec_oim, targets.size(0))
        precisions1.update(loss_score.data.item(), targets.size(0))
        precisions2.update(prec_finalscore, targets.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        print_freq = 50
        if (i + 1) % print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Loss {:.3f} ({:.3f})\t'
                  'prec_oim {:.2%} ({:.2%})\t'
                  'prec_score {:.2%} ({:.2%})\t'
                  'prec_finalscore(total) {:.2%} ({:.2%})\t'.format(
                      epoch, i + 1, len(data_loader),
                      losses.val, losses.avg,
                      precisions.val, precisions.avg,
                      precisions1.val, precisions1.avg,
                      precisions2.val, precisions2.avg))
# If multi gpus
if len(gpus) > 1:
    model = torch.nn.DataParallel(model, range(len(args.gpus))).cuda()

# Training
for epoch in range(1, args.epochs + 1):
    adjust_lr_staircase(optimizer.param_groups, [args.base_lr, args.lr],
                        epoch, decay_schedule,
                        args.staircase_decay_multiply_factor)
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()
    end = time.time()
    for i, inputs in enumerate(train_loader):
        data_time.update(time.time() - end)
        (imgs, _, labels, _) = inputs
        inputs = Variable(imgs).float().cuda()
        labels = Variable(labels).cuda()
        optimizer.zero_grad()
        final_feat_list, logits_local_rest_list, logits_local_list, \
            logits_rest_list, logits_global_list = model(inputs)
def train(train_loader, net, noise, epoch, optimizer, centroids, normalize):
    global args
    noise.requires_grad = True
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    mean = torch.Tensor(normalize.mean).view(1, 3, 1, 1).cuda()
    std = torch.Tensor(normalize.std).view(1, 3, 1, 1).cuda()
    net.eval()
    end = time.time()
    optimizer.zero_grad()
    optimizer.rescale()
    for i, (input, _, _, _) in enumerate(train_loader):
        # measure data loading time.
        data_time.update(time.time() - end)
        net.zero_grad()
        input = input.cuda()
        with torch.no_grad():
            norm_output = (input - mean) / std
            feature = net(norm_output)[0]
            scores = centroids.mm(F.normalize(feature.t(), p=2, dim=0))
            realLab = scores.max(0, keepdim=True)[1]
            _, ranks = torch.sort(scores, dim=0, descending=True)
            pos_i = ranks[0, :]
            neg_i = ranks[-1, :]
            # centroids--512*2048
            neg_feature = centroids[neg_i, :].view(-1, 2048)
            pos_feature = centroids[pos_i, :].view(-1, 2048)
        current_noise = noise
        current_noise = F.interpolate(
            current_noise.unsqueeze(0),
            mode=MODE,
            size=tuple(input.shape[-2:]),
            align_corners=True,
        ).squeeze()
        perturbed_input = torch.clamp(input + current_noise, 0, 1)
        perturbed_input_norm = (perturbed_input - mean) / std
        perturbed_feature = net(perturbed_input_norm)[0]
        optimizer.zero_grad()
        pair_loss = 10 * F.triplet_margin_loss(perturbed_feature,
                                               neg_feature, pos_feature, 0.5)
        fakePred = centroids.mm(perturbed_feature.t()).t()
        oneHotReal = torch.zeros(scores.t().shape).cuda()
        oneHotReal.scatter_(1, realLab.view(-1, 1), float(1))
        label_loss = F.relu((fakePred * oneHotReal).sum(1).mean()
                            - (fakePred * (1 - oneHotReal)).max(1)[0].mean())
        loss = pair_loss + label_loss
        loss.backward()
        losses.update(loss.item())
        optimizer.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            print(">> Train: [{0}][{1}/{2}]\t"
                  "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                  "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                  "PairLoss {loss:.4f}\t"
                  "LabelLoss {lossLab:.4f}\t"
                  "Noise l2: {noise:.4f}".format(
                      epoch + 1, i, len(train_loader),
                      batch_time=batch_time, data_time=data_time,
                      loss=pair_loss.item(), lossLab=label_loss.item(),
                      noise=noise.norm()))
    noise.requires_grad = False
    print(f"Train {epoch}: Loss: {losses.avg}")
    return losses.avg, noise
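
# Hedged sketch: the universal noise above is resized to each batch's spatial
# size before being added to the images; bilinear interpolation is assumed
# here for MODE.
def _demo_resize_noise():
    import torch
    import torch.nn.functional as F
    noise = torch.randn(3, 64, 32)                       # stored noise
    resized = F.interpolate(noise.unsqueeze(0), mode='bilinear',
                            size=(256, 128),
                            align_corners=True).squeeze(0)
    assert resized.shape == (3, 256, 128)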
def train(self, epoch, mt_train_loader, mt_test_loader, optimizer,
          noise_model, args):
    self.model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()
    end = time.time()
    for i, inputs in enumerate(mt_train_loader):
        meta_input = next(mt_test_loader)
        img, _, pid, _ = inputs
        metaTest, _, meta_pid, _ = meta_input
        adv_inputs_total, adv_labels_total, coupled_inputs = [], [], []
        adv_inputs_total_meta, adv_labels_total_meta, coupled_inputs_meta = \
            [], [], []

        # Generate perturbed images for the meta-train step.
        adv_data = create_attack_exp(inputs, noise_model)
        adv_inputs, adv_labels, adv_idxs, og_adv_inputs = adv_data
        adv_inputs_total.append(adv_inputs)
        adv_labels_total.append(adv_labels)
        coupled_inputs.append(og_adv_inputs)
        inputs = torch.cat([img.cuda()] + [_.data for _ in adv_inputs_total],
                           dim=0)
        labels = torch.cat([pid.cuda()] + [_.data for _ in adv_labels_total],
                           dim=0)
        inputs, pid = Variable(inputs), Variable(labels)
        finall_input = inputs.cuda()
        targets = pid.cuda()

        # Generate perturbed images for the meta-test step.
        adv_data = create_attack_exp(meta_input, noise_model)
        adv_inputs, adv_labels, adv_idxs, og_adv_inputs = adv_data
        adv_inputs_total_meta.append(adv_inputs)
        adv_labels_total_meta.append(adv_labels)
        coupled_inputs_meta.append(og_adv_inputs)
        meta_input = torch.cat(
            [metaTest.cuda()] + [_.data for _ in adv_inputs_total_meta],
            dim=0)
        meta_pid = torch.cat(
            [meta_pid.cuda()] + [_.data for _ in adv_labels_total_meta],
            dim=0)
        meta_input = meta_input.cuda()
        meta_pid = meta_pid.cuda()
        data_time.update(time.time() - end)

        # Meta train.
        cur_model = self.model
        output = cur_model(finall_input)
        loss, prec1 = self._memory(output, targets, epoch)
        self.model.zero_grad()
        grads = torch.autograd.grad(loss, (self.model.module.params()),
                                    create_graph=True)
        lr = optimizer.param_groups[0]["lr"]
        lr_base = optimizer.param_groups[1]["lr"]

        # Meta test on a copied model updated with the meta-train gradients.
        newMeta = models.create('resMeta', num_classes=class_meta)
        newMeta.copyModel(self.model.module)
        newMeta.update_params(lr_inner=lr, lr_base=lr_base,
                              source_params=grads, solver='adam')
        del grads
        newMeta = nn.DataParallel(newMeta).to(self.device)
        meta_out = newMeta(meta_input)
        metaloss, prec2 = self._memory(meta_out, meta_pid, epoch)

        loss_finall = metaloss + loss
        optimizer.zero_grad()
        loss_finall.backward()
        optimizer.step()
        losses.update(loss.item(), targets.size(0))
        precisions.update(prec1, targets.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % self.print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'
                  'Loss {:.3f} ({:.3f})\t'
                  'Prec {:.2%} ({:.2%})\t'.format(
                      epoch, i + 1, len(mt_train_loader),
                      batch_time.val, batch_time.avg,
                      data_time.val, data_time.avg,
                      losses.val, losses.avg,
                      precisions.val, precisions.avg))
def extract_features(model, data_loader, is_flip=False, print_freq=1,
                     metric=None):
    model.eval()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    features = OrderedDict()
    end = time.time()
    if is_flip:
        print('flip')
        for i, (imgs, flip_imgs, fnames) in enumerate(data_loader):
            data_time.update(time.time() - end)
            outputs = extract_cnn_feature(model, imgs)
            flip_outputs = extract_cnn_feature(model, flip_imgs)
            # Average the original and horizontally flipped features.
            final_outputs = (outputs + flip_outputs) / 2
            for fname, output in zip(fnames, final_outputs):
                features[fname] = output.numpy()
            batch_time.update(time.time() - end)
            end = time.time()
            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(i + 1,
                                                      len(data_loader),
                                                      batch_time.val,
                                                      batch_time.avg,
                                                      data_time.val,
                                                      data_time.avg))
    else:
        print('no flip')
        for i, (imgs, fnames) in enumerate(data_loader):
            data_time.update(time.time() - end)
            outputs = extract_cnn_feature(model, imgs)
            for fname, output in zip(fnames, outputs):
                features[fname] = output.numpy()
            batch_time.update(time.time() - end)
            end = time.time()
            if (i + 1) % print_freq == 0:
                print('Extract Features: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'.format(i + 1,
                                                      len(data_loader),
                                                      batch_time.val,
                                                      batch_time.avg,
                                                      data_time.val,
                                                      data_time.avg))
    return features
def compute_distmat(self, queryloader, galleryloader):
    self.cnnmodel.eval()
    self.classifier.eval()
    queryfeat1, queryfeat2, queryfeat3, q_ids, q_cams = \
        self.extract_video_features(queryloader)
    galleryfeat1, galleryfeat2, galleryfeat3, g_ids, g_cams = \
        self.extract_video_features(galleryloader)
    batch_time = AverageMeter()
    data_time = AverageMeter()
    end = time.time()
    print_freq = 50
    distmat = 0
    step_size = 32
    # Score the gallery in chunks of step_size to bound GPU memory.
    N = math.ceil(galleryfeat1.size(0) / (step_size + 0.001))
    for i in range(N):
        idx_bg = i * step_size
        idx_end = (i + 1) * step_size \
            if (i + 1) * step_size < galleryfeat1.size(0) \
            else galleryfeat1.size(0)
        gallery_feat1 = galleryfeat1[idx_bg:idx_end]
        gallery_feat2 = galleryfeat2[idx_bg:idx_end]
        gallery_feat3 = galleryfeat3[idx_bg:idx_end]
        queryfeat1 = queryfeat1.cuda()
        gallery_feat1 = gallery_feat1.cuda()
        queryfeat2 = queryfeat2.cuda()
        gallery_feat2 = gallery_feat2.cuda()
        queryfeat3 = queryfeat3.cuda()
        gallery_feat3 = gallery_feat3.cuda()
        batch_cls_encode1, batch_cls_encode2, batch_cls_encode3 = \
            self.classifier(queryfeat1, gallery_feat1, queryfeat2,
                            gallery_feat2, queryfeat3, gallery_feat3)
        batch_cls_size1 = batch_cls_encode1.size()
        batch_cls_encode1 = batch_cls_encode1.view(-1, 2)
        batch_cls_encode1 = F.softmax(batch_cls_encode1, 1)
        batch_cls_encode1 = batch_cls_encode1.view(batch_cls_size1[0],
                                                   batch_cls_size1[1], 2)
        batch_cls_encode1 = batch_cls_encode1[:, :, 0]
        batch_cls_size2 = batch_cls_encode2.size()
        batch_cls_encode2 = batch_cls_encode2.view(-1, 2)
        batch_cls_encode2 = F.softmax(batch_cls_encode2, 1)
        batch_cls_encode2 = batch_cls_encode2.view(batch_cls_size2[0],
                                                   batch_cls_size2[1], 2)
        batch_cls_encode2 = batch_cls_encode2[:, :, 0]
        batch_cls_size3 = batch_cls_encode3.size()
        batch_cls_encode3 = batch_cls_encode3.view(-1, 2)
        batch_cls_encode3 = F.softmax(batch_cls_encode3, 1)
        batch_cls_encode3 = batch_cls_encode3.view(batch_cls_size3[0],
                                                   batch_cls_size3[1], 2)
        batch_cls_encode3 = batch_cls_encode3[:, :, 0]
        # Fuse the three branch distances with the alpha weights.
        batch_cls_encode = (batch_cls_encode1 * self.alphas[0]
                            + batch_cls_encode2 * self.alphas[1]
                            + batch_cls_encode3 * self.alphas[2])
        if i == 0:
            distmat = batch_cls_encode.data
        else:
            distmat = torch.cat((distmat, batch_cls_encode.data), 1)
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % print_freq == 0:
            print('Extract Features: [{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'.format(
                      i + 1, len(galleryloader), batch_time.val,
                      batch_time.avg, data_time.val, data_time.avg))
    return distmat, q_ids, q_cams, g_ids, g_cams
def train(self, epoch, data_loader, optimizer):
    self.model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    precisions = AverageMeter()
    end = time.time()
    for i, inputs in enumerate(data_loader):
        data_time.update(time.time() - end)
        inputs, targets = self._parse_data(inputs)
        loss, prec1 = self._forward(inputs, targets, epoch)
        losses.update(loss.item(), targets.size(0))
        precisions.update(prec1, targets.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_time.update(time.time() - end)
        end = time.time()
        if (i + 1) % self.print_freq == 0:
            print('Epoch: [{}][{}/{}]\t'
                  'Time {:.3f} ({:.3f})\t'
                  'Data {:.3f} ({:.3f})\t'
                  'Loss {:.3f} ({:.3f})\t'
                  'Prec {:.2%} ({:.2%})\t'.format(
                      epoch, i + 1, len(data_loader),
                      batch_time.val, batch_time.avg,
                      data_time.val, data_time.avg,
                      losses.val, losses.avg,
                      precisions.val, precisions.avg))