import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# NOTE: the project-local imports below (harvester, losses, utils) use assumed
# module names; adjust them to the actual layout of the repository.
from harvester import AllTripletSelector, HardestNegativeTripletSelector
from losses import LabelSmoothingLoss
from utils import compute_eer


class TrainLoop(object):

    def __init__(self, model, optimizer, train_loader, valid_loader, margin,
                 lambda_, label_smoothing, warmup_its, max_gnorm=10.0,
                 verbose=-1, device=0, cp_name=None, save_cp=False,
                 checkpoint_path=None, checkpoint_epoch=None, swap=False,
                 softmax=False, pretrain=False, mining=False, cuda=True,
                 logger=None):

        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt = os.path.join(self.checkpoint_path, cp_name) if cp_name else os.path.join(self.checkpoint_path, 'checkpoint_{}ep.pt')
        self.cuda_mode = cuda
        self.softmax = softmax != 'none'
        self.pretrain = pretrain
        self.mining = mining
        self.model = model
        self.swap = swap
        self.lambda_ = lambda_
        self.optimizer = optimizer
        self.max_gnorm = max_gnorm
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.total_iters = 0
        self.cur_epoch = 0
        self.margin = margin
        self.harvester_mine = HardestNegativeTripletSelector(margin=self.margin, cpu=not self.cuda_mode)
        self.harvester_all = AllTripletSelector()
        self.verbose = verbose
        self.save_cp = save_cp
        self.device = device
        self.history = {'train_loss': [], 'train_loss_batch': []}
        self.logger = logger

        its_per_epoch = len(train_loader.dataset) // train_loader.batch_size + 1 if len(train_loader.dataset) % train_loader.batch_size > 0 else len(train_loader.dataset) // train_loader.batch_size

        if self.softmax:
            if label_smoothing > 0.0:
                self.ce_criterion = LabelSmoothingLoss(label_smoothing, lbl_set_size=train_loader.dataset.n_speakers)
            else:
                self.ce_criterion = torch.nn.CrossEntropyLoss()

        if self.valid_loader is not None:
            self.history['valid_loss'] = []

        if self.softmax:
            self.history['softmax_batch'] = []
            self.history['softmax'] = []

        if checkpoint_epoch is not None:
            self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))

    def train(self, n_epochs=1, save_every=1):

        while self.cur_epoch < n_epochs:

            np.random.seed()
            self.train_loader.dataset.update_lists()

            if self.verbose > 0:
                print(' ')
                print('Epoch {}/{}'.format(self.cur_epoch + 1, n_epochs))
                print('Number of training examples given new list: {}'.format(len(self.train_loader.dataset)))
                train_iter = tqdm(enumerate(self.train_loader), total=len(self.train_loader))
            else:
                train_iter = enumerate(self.train_loader)

            train_loss_epoch = 0.0

            if self.softmax and not self.pretrain:
                ce_epoch = 0.0
                for t, batch in train_iter:
                    train_loss, ce = self.train_step(batch)
                    self.history['train_loss_batch'].append(train_loss)
                    self.history['softmax_batch'].append(ce)
                    train_loss_epoch += train_loss
                    ce_epoch += ce
                    if self.logger:
                        self.logger.add_scalar('Train/Train Loss', train_loss, self.total_iters)
                        self.logger.add_scalar('Train/Triplet Loss', train_loss - ce, self.total_iters)
                        self.logger.add_scalar('Train/Cross entropy', ce, self.total_iters)
                        self.logger.add_scalar('Info/LR', self.optimizer.optimizer.param_groups[0]['lr'], self.total_iters)
                    self.total_iters += 1

                self.history['train_loss'].append(train_loss_epoch / (t + 1))
                self.history['softmax'].append(ce_epoch / (t + 1))

                if self.verbose > 0:
                    print('Total train loss, Triplet loss, and Cross-entropy: {:0.4f}, {:0.4f}, {:0.4f}'.format(self.history['train_loss'][-1], (self.history['train_loss'][-1] - self.history['softmax'][-1]), self.history['softmax'][-1]))

            elif self.pretrain:
                ce_epoch = 0.0
                for t, batch in train_iter:
                    ce = self.pretrain_step(batch)
                    self.history['train_loss_batch'].append(ce)
                    ce_epoch += ce
                    if self.logger:
                        self.logger.add_scalar('Cross entropy', ce, self.total_iters)
                        self.logger.add_scalar('Info/LR', self.optimizer.optimizer.param_groups[0]['lr'], self.total_iters)
                    self.total_iters += 1

                self.history['train_loss'].append(ce_epoch / (t + 1))

                if self.verbose > 0:
                    print('Train loss: {:0.4f}'.format(self.history['train_loss'][-1]))

            else:
                for t, batch in train_iter:
                    train_loss = self.train_step(batch)
                    self.history['train_loss_batch'].append(train_loss)
                    train_loss_epoch += train_loss
                    if self.logger:
                        self.logger.add_scalar('Train/Train Loss', train_loss, self.total_iters)
                        self.logger.add_scalar('Info/LR', self.optimizer.optimizer.param_groups[0]['lr'], self.total_iters)
                    self.total_iters += 1

                self.history['train_loss'].append(train_loss_epoch / (t + 1))

                if self.verbose > 0:
                    print('Total train loss: {:0.4f}'.format(self.history['train_loss'][-1]))

            if self.valid_loader is not None:

                scores, labels, emb, y_ = None, None, None, None

                for t, batch in enumerate(self.valid_loader):
                    scores_batch, labels_batch, emb_batch, y_batch = self.valid(batch)

                    try:
                        scores = np.concatenate([scores, scores_batch], 0)
                        labels = np.concatenate([labels, labels_batch], 0)
                        emb = np.concatenate([emb, emb_batch], 0)
                        y_ = np.concatenate([y_, y_batch], 0)
                    except:
                        # First batch: nothing accumulated yet
                        scores, labels, emb, y_ = scores_batch, labels_batch, emb_batch, y_batch

                self.history['valid_loss'].append(compute_eer(labels, scores))

                if self.verbose > 0:
                    print('Current validation loss, best validation loss, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['valid_loss'][-1], np.min(self.history['valid_loss']), 1 + np.argmin(self.history['valid_loss'])))

                if self.logger:
                    self.logger.add_scalar('Valid/EER', self.history['valid_loss'][-1], self.total_iters - 1)
                    self.logger.add_scalar('Valid/Best EER', np.min(self.history['valid_loss']), self.total_iters - 1)
                    self.logger.add_pr_curve('Valid. ROC', labels=labels, predictions=scores, global_step=self.total_iters - 1)
                    if emb.shape[0] > 20000:
                        # Subsample embeddings to keep the logger payload manageable
                        idxs = np.random.choice(np.arange(emb.shape[0]), size=20000, replace=False)
                        emb, y_ = emb[idxs, :], y_[idxs]
                    self.logger.add_histogram('Valid/Embeddings', values=emb, global_step=self.total_iters - 1)
                    self.logger.add_histogram('Valid/Scores', values=scores, global_step=self.total_iters - 1)
                    self.logger.add_histogram('Valid/Labels', values=labels, global_step=self.total_iters - 1)
                    if self.verbose > 1:
                        self.logger.add_embedding(mat=emb, metadata=list(y_), global_step=self.total_iters - 1)

            if self.verbose > 0:
                print('Current LR: {}'.format(self.optimizer.optimizer.param_groups[0]['lr']))

            self.cur_epoch += 1

            if self.valid_loader is not None and self.save_cp and (self.cur_epoch % save_every == 0 or self.history['valid_loss'][-1] < np.min([np.inf] + self.history['valid_loss'][:-1])):
                self.checkpointing()
            elif self.save_cp and self.cur_epoch % save_every == 0:
                self.checkpointing()

        if self.verbose > 0:
            print('Training done!')

        if self.valid_loader is not None:
            if self.verbose > 0:
                print('Best validation loss and corresponding epoch: {:0.4f}, {}'.format(np.min(self.history['valid_loss']), 1 + np.argmin(self.history['valid_loss'])))
            return np.min(self.history['valid_loss'])
        else:
            return np.min(self.history['train_loss'])

    def train_step(self, batch):

        self.model.train()
        self.optimizer.zero_grad()

        utt_1, utt_2, utt_3, utt_4, utt_5, y = batch
        utterances = torch.cat([utt_1, utt_2, utt_3, utt_4, utt_5], dim=0)
        y = torch.cat(5 * [y], dim=0).squeeze().contiguous()

        entropy_indices = None

        # Randomly crop the time dimension to between 1/4 and full length
        ridx = np.random.randint(utterances.size(3) // 4, utterances.size(3))
        utterances = utterances[:, :, :, :ridx].contiguous()

        if self.cuda_mode:
            utterances = utterances.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

        out, embeddings = self.model.forward(utterances)

        embeddings_norm = F.normalize(embeddings, p=2, dim=1)
        out_norm = F.normalize(out, p=2, dim=1)

        if self.mining:
            triplets_idx, entropy_indices = self.harvester_mine.get_triplets(embeddings_norm.detach(), y)
        else:
            triplets_idx = self.harvester_all.get_triplets(embeddings_norm.detach(), y)

        if self.cuda_mode:
            triplets_idx = triplets_idx.to(self.device, non_blocking=True)

        emb_a = torch.index_select(embeddings_norm, 0, triplets_idx[:, 0])
        emb_p = torch.index_select(embeddings_norm, 0, triplets_idx[:, 1])
        emb_n = torch.index_select(embeddings_norm, 0, triplets_idx[:, 2])

        loss = self.triplet_loss(emb_a, emb_p, emb_n)

        loss_log = loss.item()

        if entropy_indices is not None:
            entropy_regularizer = torch.log(torch.nn.functional.pairwise_distance(embeddings_norm, embeddings_norm[entropy_indices, :]) + 1e-6).mean()
            loss -= entropy_regularizer * self.lambda_
            if self.logger:
                self.logger.add_scalar('Train/Entropy reg.', entropy_regularizer.item(), self.total_iters)

        if self.softmax:
            ce = self.ce_criterion(self.model.out_proj(out_norm, y), y)
            loss += ce
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_gnorm)
            self.optimizer.step()
            if self.logger:
                self.logger.add_scalar('Info/Grad_norm', grad_norm, self.total_iters)
            return loss_log + ce.item(), ce.item()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_gnorm)
            self.optimizer.step()
            if self.logger:
                self.logger.add_scalar('Info/Grad_norm', grad_norm, self.total_iters)
            return loss_log

    def pretrain_step(self, batch):

        self.model.train()
        self.optimizer.zero_grad()

        utt_1, utt_2, utt_3, utt_4, utt_5, y = batch
        utterances = torch.cat([utt_1, utt_2, utt_3, utt_4, utt_5], dim=0)
        y = torch.cat(5 * [y], dim=0).squeeze().contiguous()

        ridx = np.random.randint(utterances.size(3) // 4, utterances.size(3))
        utterances = utterances[:, :, :, :ridx].contiguous()

        if self.cuda_mode:
            utterances = utterances.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

        out, embeddings = self.model.forward(utterances)
        out_norm = F.normalize(out, p=2, dim=1)

        loss = self.ce_criterion(self.model.out_proj(out_norm, y), y)

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_gnorm)
        self.optimizer.step()

        if self.logger:
            self.logger.add_scalar('Info/Grad_norm', grad_norm, self.total_iters)

        return loss.item()

    def valid(self, batch):

        self.model.eval()

        with torch.no_grad():

            utt_1, utt_2, utt_3, utt_4, utt_5, y = batch
            utterances = torch.cat([utt_1, utt_2, utt_3, utt_4, utt_5], dim=0)
            y = torch.cat(5 * [y], dim=0).squeeze().contiguous()

            ridx = np.random.randint(utterances.size(3) // 4, utterances.size(3))
            utterances = utterances[:, :, :, :ridx].contiguous()

            if self.cuda_mode:
                utterances = utterances.to(self.device, non_blocking=True)
                y = y.to(self.device, non_blocking=True)

            out, embeddings = self.model.forward(utterances)
            embeddings_norm = F.normalize(embeddings, p=2, dim=1)

            triplets_idx = self.harvester_all.get_triplets(embeddings_norm.detach(), y)

            if self.cuda_mode:
                triplets_idx = triplets_idx.cuda(self.device)

            emb_a = torch.index_select(embeddings_norm, 0, triplets_idx[:, 0])
            emb_p = torch.index_select(embeddings_norm, 0, triplets_idx[:, 1])
            emb_n = torch.index_select(embeddings_norm, 0, triplets_idx[:, 2])

            scores_p = torch.nn.functional.cosine_similarity(emb_a, emb_p)
            scores_n = torch.nn.functional.cosine_similarity(emb_a, emb_n)

            return np.concatenate([scores_p.detach().cpu().numpy(), scores_n.detach().cpu().numpy()], 0), np.concatenate([np.ones(scores_p.size(0)), np.zeros(scores_n.size(0))], 0), embeddings.detach().cpu().numpy(), y.detach().cpu().numpy()

    def triplet_loss(self, emba, embp, embn, reduce_=True):

        loss_ = torch.nn.TripletMarginLoss(margin=self.margin, p=2.0, eps=1e-06, swap=self.swap, reduction='mean' if reduce_ else 'none')(emba, embp, embn)

        return loss_

    def checkpointing(self):

        # Checkpointing
        if self.verbose > 0:
            print('Checkpointing...')

        ckpt = {
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.optimizer.state_dict(),
            'history': self.history,
            'total_iters': self.total_iters,
            'cur_epoch': self.cur_epoch
        }

        try:
            torch.save(ckpt, self.save_epoch_fmt.format(self.cur_epoch))
        except:
            torch.save(ckpt, self.save_epoch_fmt)

    def load_checkpoint(self, ckpt):

        if os.path.isfile(ckpt):

            ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)

            # Load model state
            self.model.load_state_dict(ckpt['model_state'])

            # Load optimizer state
            self.optimizer.optimizer.load_state_dict(ckpt['optimizer_state'])

            # Load history
            self.history = ckpt['history']
            self.total_iters = ckpt['total_iters']
            self.cur_epoch = ckpt['cur_epoch']

            if self.cuda_mode:
                self.model = self.model.cuda(self.device)

        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self):
        norm = 0.0
        for params in list(self.model.parameters()):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))
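# Standalone sketch (added for illustration, not part of the original sources):
# demonstrates the conventions used by the class above -- L2-normalized
# embeddings, TripletMarginLoss over (anchor, positive, negative) tuples, and
# cosine-similarity trial scores of the kind fed to compute_eer. Random tensors
# stand in for real speaker embeddings; it reuses the imports at the top of
# this file.
if __name__ == '__main__':
    torch.manual_seed(0)
    emb_a = F.normalize(torch.randn(8, 256), p=2, dim=1)                       # anchors
    emb_p = F.normalize(emb_a + 0.1 * torch.randn(8, 256), p=2, dim=1)         # positives (close to anchors)
    emb_n = F.normalize(torch.randn(8, 256), p=2, dim=1)                       # negatives

    demo_loss = torch.nn.TripletMarginLoss(margin=0.3, p=2.0, swap=False)(emb_a, emb_p, emb_n)

    scores_p = torch.nn.functional.cosine_similarity(emb_a, emb_p)
    scores_n = torch.nn.functional.cosine_similarity(emb_a, emb_n)
    scores = np.concatenate([scores_p.numpy(), scores_n.numpy()], 0)
    labels = np.concatenate([np.ones(8), np.zeros(8)], 0)

    print('triplet loss: {:0.4f}'.format(demo_loss.item()))
    print('mean target / non-target score: {:0.4f} / {:0.4f}'.format(scores_p.mean().item(), scores_n.mean().item()))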
# Alternative TrainLoop variant: batches arrive as pre-formed (anchor, positive,
# negative) tuples unless mining is enabled, and the learning rate is driven by
# torch LR schedulers rather than an optimizer wrapper with warmup. Assumes the
# same imports as the variant above.
class TrainLoop(object):

    def __init__(self, model, optimizer, train_loader, valid_loader, margin,
                 lambda_, patience, verbose=-1, device=0, cp_name=None,
                 save_cp=False, checkpoint_path=None, checkpoint_epoch=None,
                 swap=False, softmax=False, pretrain=False, mining=False,
                 cuda=True):

        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt = os.path.join(self.checkpoint_path, cp_name) if cp_name else os.path.join(self.checkpoint_path, 'checkpoint_{}ep.pt')
        self.cuda_mode = cuda
        self.softmax = softmax != 'none'
        self.pretrain = pretrain
        self.mining = mining
        self.model = model
        self.swap = swap
        self.lambda_ = lambda_
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.total_iters = 0
        self.cur_epoch = 0
        self.margin = margin
        self.harvester = HardestNegativeTripletSelector(margin=self.margin, cpu=not self.cuda_mode)
        self.verbose = verbose
        self.save_cp = save_cp
        self.device = device
        self.history = {'train_loss': [], 'train_loss_batch': []}

        if self.valid_loader is not None:
            self.history['valid_loss'] = []
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=patience, verbose=True if self.verbose > 0 else False, threshold=1e-4, min_lr=1e-7)
        else:
            self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[20, 100, 200, 300, 400], gamma=0.1)

        if self.softmax:
            self.history['softmax_batch'] = []
            self.history['softmax'] = []

        if checkpoint_epoch is not None:
            self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))

    def train(self, n_epochs=1, save_every=1):

        while self.cur_epoch < n_epochs:

            np.random.seed()

            if self.verbose > 0:
                print(' ')
                print('Epoch {}/{}'.format(self.cur_epoch + 1, n_epochs))
                train_iter = tqdm(enumerate(self.train_loader))
            else:
                train_iter = enumerate(self.train_loader)

            train_loss_epoch = 0.0

            if self.softmax and not self.pretrain:
                ce_epoch = 0.0
                for t, batch in train_iter:
                    train_loss, ce = self.train_step(batch)
                    self.history['train_loss_batch'].append(train_loss)
                    self.history['softmax_batch'].append(ce)
                    train_loss_epoch += train_loss
                    ce_epoch += ce
                    self.total_iters += 1

                self.history['train_loss'].append(train_loss_epoch / (t + 1))
                self.history['softmax'].append(ce_epoch / (t + 1))

                if self.verbose > 0:
                    print('Total train loss, Triplet loss, and Cross-entropy: {:0.4f}, {:0.4f}, {:0.4f}'.format(self.history['train_loss'][-1], (self.history['train_loss'][-1] - self.history['softmax'][-1]), self.history['softmax'][-1]))

            elif self.pretrain:
                ce_epoch = 0.0
                for t, batch in train_iter:
                    ce = self.pretrain_step(batch)
                    self.history['train_loss_batch'].append(ce)
                    ce_epoch += ce
                    self.total_iters += 1

                self.history['train_loss'].append(ce_epoch / (t + 1))

                if self.verbose > 0:
                    print('Train loss: {:0.4f}'.format(self.history['train_loss'][-1]))

            else:
                for t, batch in train_iter:
                    train_loss = self.train_step(batch)
                    self.history['train_loss_batch'].append(train_loss)
                    train_loss_epoch += train_loss
                    self.total_iters += 1

                self.history['train_loss'].append(train_loss_epoch / (t + 1))

                if self.verbose > 0:
                    print('Total train loss: {:0.4f}'.format(self.history['train_loss'][-1]))

            if self.valid_loader is not None:

                scores, labels = None, None

                for t, batch in enumerate(self.valid_loader):
                    scores_batch, labels_batch = self.valid(batch)

                    try:
                        scores = np.concatenate([scores, scores_batch], 0)
                        labels = np.concatenate([labels, labels_batch], 0)
                    except:
                        # First batch: nothing accumulated yet
                        scores, labels = scores_batch, labels_batch

                self.history['valid_loss'].append(compute_eer(labels, scores))

                if self.verbose > 0:
                    print('Current validation loss, best validation loss, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['valid_loss'][-1], np.min(self.history['valid_loss']), 1 + np.argmin(self.history['valid_loss'])))

                self.scheduler.step(self.history['valid_loss'][-1])

            else:
                self.scheduler.step()

            if self.verbose > 0:
                print('Current LR: {}'.format(self.optimizer.param_groups[0]['lr']))

            self.cur_epoch += 1

            if self.valid_loader is not None and self.save_cp and (self.cur_epoch % save_every == 0 or self.history['valid_loss'][-1] < np.min([np.inf] + self.history['valid_loss'][:-1])):
                self.checkpointing()
            elif self.save_cp and self.cur_epoch % save_every == 0:
                self.checkpointing()

        if self.verbose > 0:
            print('Training done!')

        if self.valid_loader is not None:
            if self.verbose > 0:
                print('Best validation loss and corresponding epoch: {:0.4f}, {}'.format(np.min(self.history['valid_loss']), 1 + np.argmin(self.history['valid_loss'])))
            return np.min(self.history['valid_loss'])
        else:
            return np.min(self.history['train_loss'])

    def train_step(self, batch):

        self.model.train()
        self.optimizer.zero_grad()

        if self.mining:
            utterances, y = batch
            utterances.resize_(utterances.size(0) * utterances.size(1), utterances.size(2), utterances.size(3), utterances.size(4))
            y.resize_(y.numel())
        elif self.softmax:
            utt_a, utt_p, utt_n, y = batch
        else:
            utt_a, utt_p, utt_n = batch

        entropy_indices = None

        if self.mining:
            # Randomly crop the time dimension to between 1/4 and full length
            ridx = np.random.randint(utterances.size(3) // 4, utterances.size(3))
            utterances = utterances[:, :, :, :ridx]

            if self.cuda_mode:
                utterances = utterances.cuda(self.device)

            embeddings = self.model.forward(utterances)
            embeddings_norm = F.normalize(embeddings, p=2, dim=1)

            triplets_idx, entropy_indices = self.harvester.get_triplets(embeddings_norm.detach(), y)

            if self.cuda_mode:
                triplets_idx = triplets_idx.cuda(self.device)

            emb_a = torch.index_select(embeddings_norm, 0, triplets_idx[:, 0])
            emb_p = torch.index_select(embeddings_norm, 0, triplets_idx[:, 1])
            emb_n = torch.index_select(embeddings_norm, 0, triplets_idx[:, 2])

        else:
            ridx = np.random.randint(utt_a.size(3) // 4, utt_a.size(3))
            utt_a, utt_p, utt_n = utt_a[:, :, :, :ridx], utt_p[:, :, :, :ridx], utt_n[:, :, :, :ridx]

            if self.cuda_mode:
                utt_a, utt_p, utt_n = utt_a.cuda(self.device), utt_p.cuda(self.device), utt_n.cuda(self.device)

            emb_a, emb_p, emb_n = self.model.forward(utt_a), self.model.forward(utt_p), self.model.forward(utt_n)

            emb_a = torch.div(emb_a, torch.norm(emb_a, 2, 1).unsqueeze(1).expand_as(emb_a))
            emb_p = torch.div(emb_p, torch.norm(emb_p, 2, 1).unsqueeze(1).expand_as(emb_p))
            emb_n = torch.div(emb_n, torch.norm(emb_n, 2, 1).unsqueeze(1).expand_as(emb_n))

            embeddings_norm = F.normalize(emb_a, p=2, dim=1)

        loss = self.triplet_loss(emb_a, emb_p, emb_n)

        loss_log = loss.item()

        if entropy_indices is not None:
            entropy_regularizer = torch.nn.functional.pairwise_distance(embeddings_norm, embeddings_norm[entropy_indices, :]).mean()
            loss -= entropy_regularizer * self.lambda_

        if self.softmax:
            if self.cuda_mode:
                y = y.cuda(self.device).squeeze()
            ce = F.cross_entropy(self.model.out_proj(embeddings_norm, y), y)
            loss += ce
            loss.backward()
            self.optimizer.step()
            return loss_log + ce.item(), ce.item()
        else:
            loss.backward()
            self.optimizer.step()
            return loss_log

    def pretrain_step(self, batch):

        self.model.train()
        self.optimizer.zero_grad()

        utt, y = batch

        ridx = np.random.randint(utt.size(3) // 2, utt.size(3))
        utt = utt[:, :, :, :ridx]

        if self.cuda_mode:
            utt, y = utt.cuda(self.device), y.cuda(self.device).squeeze()

        embeddings = self.model.forward(utt)
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)

        loss = F.cross_entropy(self.model.out_proj(embeddings_norm, y), y)

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def valid(self, batch):

        self.model.eval()

        with torch.no_grad():

            xa, xp, xn = batch

            ridx = np.random.randint(xa.size(3) // 2, xa.size(3))
            xa, xp, xn = xa[:, :, :, :ridx], xp[:, :, :, :ridx], xn[:, :, :, :ridx]

            if self.cuda_mode:
                xa = xa.contiguous().cuda(self.device)
                xp = xp.contiguous().cuda(self.device)
                xn = xn.contiguous().cuda(self.device)

            emb_a = self.model.forward(xa)
            emb_p = self.model.forward(xp)
            emb_n = self.model.forward(xn)

            scores_p = torch.nn.functional.cosine_similarity(emb_a, emb_p)
            scores_n = torch.nn.functional.cosine_similarity(emb_a, emb_n)

            return np.concatenate([scores_p.detach().cpu().numpy(), scores_n.detach().cpu().numpy()], 0), np.concatenate([np.ones(scores_p.size(0)), np.zeros(scores_n.size(0))], 0)

    def triplet_loss(self, emba, embp, embn, reduce_=True):

        loss_ = torch.nn.TripletMarginLoss(margin=self.margin, p=2.0, eps=1e-06, swap=self.swap, reduction='mean' if reduce_ else 'none')(emba, embp, embn)

        return loss_

    def checkpointing(self):

        # Checkpointing
        if self.verbose > 0:
            print('Checkpointing...')

        ckpt = {
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'history': self.history,
            'total_iters': self.total_iters,
            'cur_epoch': self.cur_epoch
        }

        try:
            torch.save(ckpt, self.save_epoch_fmt.format(self.cur_epoch))
        except:
            torch.save(ckpt, self.save_epoch_fmt)

    def load_checkpoint(self, ckpt):

        if os.path.isfile(ckpt):

            ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)

            # Load model state
            self.model.load_state_dict(ckpt['model_state'])

            # Load optimizer state
            self.optimizer.load_state_dict(ckpt['optimizer_state'])

            # Load scheduler state
            self.scheduler.load_state_dict(ckpt['scheduler_state'])

            # Load history
            self.history = ckpt['history']
            self.total_iters = ckpt['total_iters']
            self.cur_epoch = ckpt['cur_epoch']

            if self.cuda_mode:
                self.model = self.model.cuda(self.device)

        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self):
        norm = 0.0
        for params in list(self.model.parameters()):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))
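# Standalone sketch (illustration only, not from the original sources): shows the
# LR schedule this variant relies on when a validation loader is present --
# ReduceLROnPlateau halves the learning rate once the validation EER stops
# improving for `patience` epochs. A dummy one-parameter model stands in for the
# speaker embedding network; the fake EER values are made up for the demo.
if __name__ == '__main__':
    dummy = torch.nn.Linear(10, 10)
    opt = torch.optim.SGD(dummy.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=2, threshold=1e-4, min_lr=1e-7)

    fake_eer = [0.20, 0.15, 0.14, 0.14, 0.14, 0.14, 0.14]  # stops improving after the third epoch
    for epoch, eer in enumerate(fake_eer):
        sched.step(eer)
        print('epoch {}: EER {:0.2f}, LR {}'.format(epoch + 1, eer, opt.param_groups[0]['lr']))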