def testSyncBatchNormSyncEval(self):
    """Compare ``nn.BatchNorm1d`` against a 2-GPU ``SynchronizedBatchNorm1d``.

    Both layers use 10 features, ``eps=1e-5`` and ``affine=False``. The
    synchronized layer is replicated on devices 0 and 1 through
    ``DataParallelWithCallback``. The input is a ``(16, 10)`` tensor from
    ``torch.rand``. ``False`` is passed as the fourth argument of
    ``self._checkBatchNormResult`` (presumably the training-mode flag --
    confirm against that helper).

    NOTE(review): a method with this exact name and body is defined twice in
    this file; inside a class body the later definition replaces the earlier.
    """
    reference = nn.BatchNorm1d(10, eps=1e-5, affine=False)
    replicated = DataParallelWithCallback(
        SynchronizedBatchNorm1d(10, eps=1e-5, affine=False),
        device_ids=[0, 1],
    )
    for module in (reference, replicated):
        module.cuda()
    self._checkBatchNormResult(
        reference, replicated, torch.rand(16, 10), False, cuda=True)
def testSyncBatchNorm2DSyncTrain(self):
    """Compare ``nn.BatchNorm2d`` against a 2-GPU ``SynchronizedBatchNorm2d``.

    Both layers use 10 channels with default settings. The synchronized layer
    is replicated on devices 0 and 1 through ``DataParallelWithCallback``.
    The input is a ``(16, 10, 16, 16)`` tensor from ``torch.rand``. ``True``
    is passed as the fourth argument of ``self._checkBatchNormResult``
    (presumably the training-mode flag -- confirm against that helper).

    NOTE(review): a method with this exact name and body is defined twice in
    this file; inside a class body the later definition replaces the earlier.
    """
    reference = nn.BatchNorm2d(10)
    replicated = DataParallelWithCallback(
        SynchronizedBatchNorm2d(10),
        device_ids=[0, 1],
    )
    for module in (reference, replicated):
        module.cuda()
    self._checkBatchNormResult(
        reference, replicated, torch.rand(16, 10, 16, 16), True, cuda=True)
def testSyncBatchNorm2DSyncTrain(self):
    """Check ``SynchronizedBatchNorm2d`` on two GPUs against ``nn.BatchNorm2d``.

    Both layers have 10 channels. The synchronized layer is wrapped in
    ``DataParallelWithCallback`` over devices 0 and 1. A random
    ``(16, 10, 16, 16)`` input is handed to ``self._checkBatchNormResult``
    with ``True`` as its fourth argument (presumably the training-mode flag
    -- confirm).

    NOTE(review): this name and body appear twice in this file; inside a
    class body only the later definition survives.
    """
    plain_bn = nn.BatchNorm2d(10)
    wrapped_sync_bn = DataParallelWithCallback(
        SynchronizedBatchNorm2d(10), device_ids=[0, 1])
    plain_bn.cuda()
    wrapped_sync_bn.cuda()
    sample = torch.rand(16, 10, 16, 16)
    self._checkBatchNormResult(plain_bn, wrapped_sync_bn, sample, True, cuda=True)
def testSyncBatchNormSyncEval(self):
    """Check ``SynchronizedBatchNorm1d`` on two GPUs against ``nn.BatchNorm1d``.

    Both layers have 10 features, ``eps=1e-5`` and no affine parameters. The
    synchronized layer is wrapped in ``DataParallelWithCallback`` over
    devices 0 and 1. A random ``(16, 10)`` input is handed to
    ``self._checkBatchNormResult`` with ``False`` as its fourth argument
    (presumably the training-mode flag -- confirm).

    NOTE(review): this name and body appear twice in this file; inside a
    class body only the later definition survives.
    """
    plain_bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
    wrapped_sync_bn = DataParallelWithCallback(
        SynchronizedBatchNorm1d(10, eps=1e-5, affine=False), device_ids=[0, 1])
    plain_bn.cuda()
    wrapped_sync_bn.cuda()
    sample = torch.rand(16, 10)
    self._checkBatchNormResult(plain_bn, wrapped_sync_bn, sample, False, cuda=True)
class Model:
    """Training / feature-extraction wrapper around a ``SetNet`` encoder.

    Owns the encoder and a ``TripletLoss`` criterion, both wrapped in
    ``DataParallelWithCallback`` and moved to CUDA. It also owns an Adam
    optimizer over the encoder parameters and the running loss statistics
    that ``fit`` prints.
    """

    def __init__(self,
                 hidden_dim,
                 lr,
                 hard_or_full_trip,
                 margin,
                 num_workers,
                 batch_size,
                 restore_iter,
                 total_iter,
                 save_name,
                 train_pid_num,
                 frame_num,
                 model_name,
                 train_source,
                 test_source,
                 img_size=64):
        """Build the networks and the optimizer.

        Args:
            hidden_dim: passed to ``SetNet``.
            lr: Adam learning rate; ``fit`` re-applies it to every param group.
            hard_or_full_trip: ``'hard'`` or ``'full'``; selects which
                triplet-loss output ``fit`` back-propagates.
            margin: passed to ``TripletLoss``.
            num_workers: ``DataLoader`` worker count.
            batch_size: a ``(P, M)`` pair. ``P * M`` is passed to
                ``TripletLoss``, and the pair is handed to ``TripletSampler``
                (presumably identities x samples per identity -- confirm).
            restore_iter: checkpoint iteration to resume from; 0 starts fresh.
            total_iter: iteration at which ``fit`` stops.
            save_name: used to build checkpoint file names.
            train_pid_num: stored; not used elsewhere in this class.
            frame_num: frames drawn per sequence in ``'random'`` sampling.
            model_name: used to build checkpoint file names.
            train_source: training dataset; must expose ``label_set``.
            test_source: dataset used by ``transform('test')``.
            img_size: stored; not used elsewhere in this class.
        """
        self.save_name = save_name
        self.train_pid_num = train_pid_num
        self.train_source = train_source
        self.test_source = test_source

        self.hidden_dim = hidden_dim
        self.lr = lr
        self.hard_or_full_trip = hard_or_full_trip
        self.margin = margin
        self.frame_num = frame_num
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.model_name = model_name
        self.P, self.M = batch_size

        self.restore_iter = restore_iter
        self.total_iter = total_iter

        self.img_size = img_size

        self.encoder = SetNet(self.hidden_dim).float()
        self.encoder = DataParallelWithCallback(self.encoder)
        self.triplet_loss = TripletLoss(self.P * self.M, self.hard_or_full_trip, self.margin).float()
        self.triplet_loss = DataParallelWithCallback(self.triplet_loss)
        self.encoder.cuda()
        self.triplet_loss.cuda()

        self.optimizer = optim.Adam([
            {'params': self.encoder.parameters()},
        ], lr=self.lr)

        # Per-iteration statistics; printed and reset every 100 iterations in fit().
        self.hard_loss_metric = []
        self.full_loss_metric = []
        self.full_loss_num = []
        self.dist_list = []
        self.mean_dist = 0.01

        # Read by collate_fn: fit() sets 'random', transform() sets 'all'.
        self.sample_type = 'all'

    @staticmethod
    def _split_for_gpus(batch_size, device_count):
        """Partition sample indices ``0 .. batch_size - 1`` into contiguous chunks.

        One chunk per GPU is intended. Each chunk holds
        ``per_chunk = ceil(batch_size / n)`` indices, where
        ``n = max(1, min(device_count, batch_size))``. The exception is the
        last chunk, which may be shorter.

        Chunks are never empty, so fewer than ``n`` chunks are returned when
        the split is uneven. For example, 5 samples on 4 GPUs gives
        ``[[0, 1], [2, 3], [4]]``.

        Returns:
            ``(chunks, per_chunk)``: a list of index lists and the nominal
            chunk size.
        """
        gpu_num = max(1, min(device_count, batch_size))
        per_chunk = max(1, math.ceil(batch_size / gpu_num))
        chunks = [list(range(start, min(start + per_chunk, batch_size)))
                  for start in range(0, batch_size, per_chunk)]
        return chunks, per_chunk

    def collate_fn(self, batch):
        """Merge dataset items into one batch for the encoder.

        Each item is indexed as ``(features, frame_set, view, seq_type, label)``.
        ``features`` is a list of array-like objects exposing ``.values``
        (and ``.loc`` in ``'random'`` mode). ``frame_set`` is a sequence of
        frame ids.

        In ``'random'`` mode, ``self.frame_num`` frame ids are drawn with
        replacement per item (``random.choices``). Each feature is then
        stacked across the batch into one array.

        Otherwise (``'all'`` mode) every frame is kept. Items are grouped into
        per-GPU chunks and concatenated along axis 0 within a chunk. Each
        chunk is then zero-padded on axis 0 to the largest chunk length.
        The feature arrays must be 3-D for the padding spec below.

        Returns:
            ``[seqs, view, seq_type, label, batch_frames]``.
            ``seqs`` holds one array per feature. ``batch_frames`` is ``None``
            in ``'random'`` mode. Otherwise it is an int array of per-item
            frame counts with shape ``(num_chunks, per_chunk)``; zeros fill
            the unused slots of a short last chunk.
        """
        batch_size = len(batch)
        feature_num = len(batch[0][0])
        raw_seqs = [item[0] for item in batch]
        frame_sets = [item[1] for item in batch]
        view = [item[2] for item in batch]
        seq_type = [item[3] for item in batch]
        label = [item[4] for item in batch]

        def select_frame(index):
            sample = raw_seqs[index]
            if self.sample_type == 'random':
                frame_id_list = random.choices(frame_sets[index], k=self.frame_num)
                return [feature.loc[frame_id_list].values for feature in sample]
            return [feature.values for feature in sample]

        seqs = [select_frame(i) for i in range(batch_size)]
        batch_frames = None

        if self.sample_type == 'random':
            seqs = [np.asarray([sample[j] for sample in seqs]) for j in range(feature_num)]
        else:
            # Non-empty chunks only: an empty chunk would make np.concatenate raise.
            chunks, batch_per_gpu = self._split_for_gpus(
                batch_size, torch.cuda.device_count())
            frame_counts = [[len(frame_sets[i]) for i in chunk] for chunk in chunks]
            # Only the last chunk can be short; pad it so the array is rectangular.
            frame_counts[-1].extend([0] * (batch_per_gpu - len(frame_counts[-1])))
            max_sum_frame = np.max([np.sum(counts) for counts in frame_counts])
            packed = [[np.concatenate([seqs[i][j] for i in chunk], 0) for chunk in chunks]
                      for j in range(feature_num)]
            seqs = [np.asarray([
                np.pad(arr,
                       ((0, max_sum_frame - arr.shape[0]), (0, 0), (0, 0)),
                       'constant',
                       constant_values=0)
                for arr in per_feature])
                for per_feature in packed]
            batch_frames = np.asarray(frame_counts)

        return [seqs, view, seq_type, label, batch_frames]

    def fit(self):
        """Train the encoder with the triplet loss until ``total_iter``.

        Resumes from checkpoint ``restore_iter`` when it is non-zero. Every
        100 iterations it saves a checkpoint and prints the running
        statistics. Every 1000 iterations it prints the elapsed wall time.

        Raises:
            ValueError: if ``hard_or_full_trip`` is not ``'hard'`` or ``'full'``.
        """
        if self.hard_or_full_trip not in ('hard', 'full'):
            raise ValueError(
                "hard_or_full_trip must be 'hard' or 'full', got %r"
                % (self.hard_or_full_trip,))

        if self.restore_iter != 0:
            self.load(self.restore_iter)

        self.encoder.train()
        self.sample_type = 'random'
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr
        triplet_sampler = TripletSampler(self.train_source, self.batch_size)
        train_loader = tordata.DataLoader(
            dataset=self.train_source,
            batch_sampler=triplet_sampler,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        train_label_set = sorted(self.train_source.label_set)
        # O(1) label -> class-index lookup instead of list.index per label.
        label_to_index = {lbl: idx for idx, lbl in enumerate(train_label_set)}

        _time1 = datetime.now()
        for seq, view, seq_type, label, batch_frame in train_loader:
            self.restore_iter += 1
            self.optimizer.zero_grad()

            seq = [self.np2var(s).float() for s in seq]
            if batch_frame is not None:
                batch_frame = self.np2var(batch_frame).int()

            feature, _ = self.encoder(*seq, batch_frame)

            target_label = self.np2var(
                np.array([label_to_index[lbl] for lbl in label])).long()

            # feature is (n, num_bin, dim) (cf. transform); move bins first and
            # repeat the label row once per bin for the loss.
            triplet_feature = feature.permute(1, 0, 2).contiguous()
            triplet_label = target_label.unsqueeze(0).repeat(triplet_feature.size(0), 1)
            (full_loss_metric, hard_loss_metric, mean_dist, full_loss_num
             ) = self.triplet_loss(triplet_feature, triplet_label)
            if self.hard_or_full_trip == 'hard':
                loss = hard_loss_metric.mean()
            else:
                loss = full_loss_metric.mean()

            self.hard_loss_metric.append(hard_loss_metric.mean().data.cpu().numpy())
            self.full_loss_metric.append(full_loss_metric.mean().data.cpu().numpy())
            self.full_loss_num.append(full_loss_num.mean().data.cpu().numpy())
            self.dist_list.append(mean_dist.mean().data.cpu().numpy())

            # Skip the update when the loss is (numerically) zero.
            if loss > 1e-9:
                loss.backward()
                self.optimizer.step()

            if self.restore_iter % 1000 == 0:
                print(datetime.now() - _time1)
                _time1 = datetime.now()

            if self.restore_iter % 100 == 0:
                self.save()
                print('iter {}:'.format(self.restore_iter), end='')
                print(', hard_loss_metric={0:.8f}'.format(np.mean(self.hard_loss_metric)), end='')
                print(', full_loss_metric={0:.8f}'.format(np.mean(self.full_loss_metric)), end='')
                print(', full_loss_num={0:.8f}'.format(np.mean(self.full_loss_num)), end='')
                self.mean_dist = np.mean(self.dist_list)
                print(', mean_dist={0:.8f}'.format(self.mean_dist), end='')
                print(', lr=%f' % self.optimizer.param_groups[0]['lr'], end='')
                print(', hard or full=%r' % self.hard_or_full_trip)
                sys.stdout.flush()
                self.hard_loss_metric = []
                self.full_loss_metric = []
                self.full_loss_num = []
                self.dist_list = []

            # '>=' so resuming at or past total_iter cannot loop forever.
            if self.restore_iter >= self.total_iter:
                break

    def ts2var(self, x):
        """Wrap tensor ``x`` in ``autograd.Variable`` and move it to CUDA."""
        return autograd.Variable(x).cuda()

    def np2var(self, x):
        """Convert numpy array ``x`` to a CUDA tensor via ``ts2var``."""
        return self.ts2var(torch.from_numpy(x))

    def transform(self, flag, batch_size=1):
        """Extract flattened encoder features for every item of a dataset.

        Args:
            flag: ``'test'`` selects ``test_source``; any other value selects
                ``train_source``.
            batch_size: items per ``DataLoader`` batch (read sequentially).

        Returns:
            ``(features, views, seq_types, labels)``. ``features`` stacks each
            batch's encoder output, reshaped to ``(n, -1)``, along axis 0 in
            dataset order.
        """
        self.encoder.eval()
        source = self.test_source if flag == 'test' else self.train_source
        self.sample_type = 'all'
        data_loader = tordata.DataLoader(
            dataset=source,
            batch_size=batch_size,
            sampler=tordata.sampler.SequentialSampler(source),
            collate_fn=self.collate_fn,
            num_workers=self.num_workers)

        feature_list = []
        view_list = []
        seq_type_list = []
        label_list = []

        # Inference only: no autograd graph is needed, which saves GPU memory.
        with torch.no_grad():
            for seq, view, seq_type, label, batch_frame in data_loader:
                seq = [self.np2var(s).float() for s in seq]
                if batch_frame is not None:
                    batch_frame = self.np2var(batch_frame).int()

                feature, _ = self.encoder(*seq, batch_frame)
                n, _, _ = feature.size()
                feature_list.append(feature.view(n, -1).cpu().numpy())
                view_list += view
                seq_type_list += seq_type
                label_list += label

        return np.concatenate(feature_list, 0), view_list, seq_type_list, label_list

    def save(self):
        """Write encoder and optimizer state dicts for the current iteration.

        Files go to ``checkpoint/<model_name>/<save_name>-<iter:05>-{encoder,optimizer}.ptm``.
        """
        os.makedirs(osp.join('checkpoint', self.model_name), exist_ok=True)
        torch.save(self.encoder.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-encoder.ptm'.format(
                                self.save_name, self.restore_iter)))
        torch.save(self.optimizer.state_dict(),
                   osp.join('checkpoint', self.model_name,
                            '{}-{:0>5}-optimizer.ptm'.format(
                                self.save_name, self.restore_iter)))

    def load(self, restore_iter):
        """Load encoder and optimizer state saved by ``save`` at ``restore_iter``."""
        self.encoder.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-encoder.ptm'.format(self.save_name, restore_iter))))
        self.optimizer.load_state_dict(torch.load(osp.join(
            'checkpoint', self.model_name,
            '{}-{:0>5}-optimizer.ptm'.format(self.save_name, restore_iter))))