Example #1
	def __init__(self, model, optimizer, train_loader, valid_loader, margin, lambda_, patience, verbose=-1, cp_name=None, save_cp=False, checkpoint_path=None, checkpoint_epoch=None, swap=False, cuda=True):
		if checkpoint_path is None:
			# Save to current directory
			self.checkpoint_path = os.getcwd()
		else:
			self.checkpoint_path = checkpoint_path
			if not os.path.isdir(self.checkpoint_path):
				os.mkdir(self.checkpoint_path)

		self.save_epoch_fmt = os.path.join(self.checkpoint_path, cp_name) if cp_name else os.path.join(self.checkpoint_path, 'checkpoint_{}ep.pt')
		self.cuda_mode = cuda
		self.model = model
		self.optimizer = optimizer
		self.train_loader = train_loader
		self.valid_loader = valid_loader
		self.history = {'train_loss': [], 'train_loss_batch': [], 'triplet_loss': [], 'triplet_loss_batch': [], 'ce_loss': [], 'ce_loss_batch': [],'ErrorRate': [], 'EER': []}
		self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=patience, verbose=True if verbose>0 else False, threshold=1e-4, min_lr=1e-8)
		self.total_iters = 0
		self.cur_epoch = 0
		self.lambda_ = lambda_
		self.swap = swap
		self.margin = margin
		self.harvester = HardestNegativeTripletSelector(margin=0.1, cpu=not self.cuda_mode)
		self.harvester_val = AllTripletSelector()
		self.verbose = verbose
		self.save_cp = save_cp
		self.device = next(self.model.parameters()).device

		if checkpoint_epoch is not None:
			self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))
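These constructor and training-loop snippets all assume the same module-level imports from the surrounding repository. A minimal sketch of what they rely on is given below; the standard-library and third-party imports are certain, while the commented project-internal import paths are only guesses for illustration:

import os
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import init
from tqdm import tqdm

# Project-specific helpers used throughout the examples; the module paths below are
# assumptions, not the repository's actual layout.
# from losses import LabelSmoothingLoss
# from triplet_selectors import AllTripletSelector, HardestNegativeTripletSelector
# from utils import compute_eer
# from data_loader import Loader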
Example #2
	def __init__(self, model, optimizer, train_loader, valid_loader, patience, verbose=-1, cp_name=None, save_cp=False, checkpoint_path=None, checkpoint_epoch=None, pretrain=False, cuda=True):
		if checkpoint_path is None:
			# Save to current directory
			self.checkpoint_path = os.getcwd()
		else:
			self.checkpoint_path = checkpoint_path
			if not os.path.isdir(self.checkpoint_path):
				os.mkdir(self.checkpoint_path)

		self.save_epoch_fmt = os.path.join(self.checkpoint_path, cp_name) if cp_name else os.path.join(self.checkpoint_path, 'checkpoint_{}ep.pt')
		self.cuda_mode = cuda
		self.pretrain = pretrain
		self.model = model
		self.optimizer = optimizer
		self.train_loader = train_loader
		self.valid_loader = valid_loader
		self.total_iters = 0
		self.cur_epoch = 0
		self.harvester = AllTripletSelector()
		self.verbose = verbose
		self.save_cp = save_cp
		self.device = next(self.model.parameters()).device
		self.history = {'train_loss': [], 'train_loss_batch': [], 'ce_loss': [], 'ce_loss_batch': [], 'bin_loss': [], 'bin_loss_batch': []}

		if self.valid_loader is not None:
			self.history['e2e_eer'] = []
			self.history['cos_eer'] = []
			self.history['ErrorRate'] = []
			self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=patience, verbose=True if self.verbose>0 else False, threshold=1e-4, min_lr=1e-7)
		else:
			self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[20, 100, 200, 300, 400], gamma=0.1)

		if checkpoint_epoch is not None:
			self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))
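The constructor above picks a scheduler based on whether a validation loader exists: ReduceLROnPlateau is stepped with a monitored metric (an EER or error rate in these loops), while MultiStepLR is stepped once per epoch with no argument. A minimal sketch of the two call patterns follows; the optimizer and the metric value are placeholders, and the milestones mirror the snippet:

import torch

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)

# With a validation loader: step with the monitored metric.
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, factor=0.5, patience=10, threshold=1e-4, min_lr=1e-7)
plateau.step(0.25)

# Without one: decay at fixed epoch milestones, stepping once per epoch.
multistep = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[20, 100, 200, 300, 400], gamma=0.1)
multistep.step()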
Example #3
    def __init__(self,
                 model,
                 optimizer,
                 train_loader,
                 valid_loader,
                 margin,
                 lambda_,
                 verbose=-1,
                 cp_name=None,
                 save_cp=False,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 swap=False,
                 cuda=True):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt = os.path.join(
            self.checkpoint_path, cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'checkpoint_{}ep.pt')
        self.cuda_mode = cuda
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.history = {
            'train_loss': [],
            'train_loss_batch': [],
            'triplet_loss': [],
            'triplet_loss_batch': [],
            'ce_loss': [],
            'ce_loss_batch': [],
            'EER': [],
            'acc_1': [],
            'acc_5': []
        }
        self.total_iters = 0
        self.cur_epoch = 0
        self.lambda_ = lambda_
        self.swap = swap
        self.margin = margin
        self.harvester = HardestNegativeTripletSelector(margin=0.1,
                                                        cpu=not self.cuda_mode)
        self.harvester_val = AllTripletSelector()
        self.verbose = verbose
        self.save_cp = save_cp
        self.device = next(self.model.parameters()).device
        self.base_lr = self.optimizer.param_groups[0]['lr']

        if checkpoint_epoch is not None:
            self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))
Example #4
	def __init__(self, model, optimizer, train_loader, valid_loader, margin, lambda_, max_gnorm, patience, lr_factor, label_smoothing, verbose=-1, cp_name=None, save_cp=False, checkpoint_path=None, checkpoint_epoch=None, pretrain=False, swap=False, cuda=True, logger=None):
		if checkpoint_path is None:
			# Save to current directory
			self.checkpoint_path = os.getcwd()
		else:
			self.checkpoint_path = checkpoint_path
			if not os.path.isdir(self.checkpoint_path):
				os.mkdir(self.checkpoint_path)

		self.save_epoch_fmt = os.path.join(self.checkpoint_path, cp_name) if cp_name else os.path.join(self.checkpoint_path, 'checkpoint_{}ep.pt')
		self.cuda_mode = cuda
		self.pretrain = pretrain
		self.model = model
		self.optimizer = optimizer
		self.patience = patience
		self.max_gnorm = max_gnorm
		self.lr_factor = lr_factor
		self.lambda_ = lambda_
		self.swap = swap
		self.margin = margin
		self.train_loader = train_loader
		self.valid_loader = valid_loader
		self.total_iters = 0
		self.cur_epoch = 0
		self.harvester = HardestNegativeTripletSelector(margin=0.1, cpu=not self.cuda_mode)
		self.harvester_val = AllTripletSelector()
		self.verbose = verbose
		self.save_cp = save_cp
		self.device = next(self.model.parameters()).device
		self.history = {'train_loss': [], 'train_loss_batch': [], 'ce_loss': [], 'ce_loss_batch': [], 'triplet_loss': [], 'triplet_loss_batch': []}
		self.disc_label_smoothing = label_smoothing*0.5
		self.base_lr = self.optimizer.param_groups[0]['lr']
		self.logger = logger

		if label_smoothing>0.0:
			self.ce_criterion = LabelSmoothingLoss(label_smoothing, lbl_set_size=1000)
		else:
			self.ce_criterion = torch.nn.CrossEntropyLoss()

		if self.valid_loader is not None:
			self.history['cos_eer'] = []
			self.history['acc_1'] = []
			self.history['acc_5'] = []

		if checkpoint_epoch is not None:
			self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))
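max_gnorm in this constructor caps the total gradient norm before each optimizer step; Example #9 further down applies it with torch.nn.utils.clip_grad_norm_ inside train_step. In isolation the call looks like this (the toy model and the max_norm value are placeholders):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
# Rescales gradients in place so their total norm is at most max_norm, returning the pre-clip norm.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()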
Example #5
class TrainLoop(object):

	def __init__(self, model, optimizer, train_loader, valid_loader, patience, label_smoothing, verbose=-1, cp_name=None, save_cp=False, checkpoint_path=None, checkpoint_epoch=None, pretrain=False, cuda=True):
		if checkpoint_path is None:
			# Save to current directory
			self.checkpoint_path = os.getcwd()
		else:
			self.checkpoint_path = checkpoint_path
			if not os.path.isdir(self.checkpoint_path):
				os.mkdir(self.checkpoint_path)

		self.save_epoch_fmt = os.path.join(self.checkpoint_path, cp_name) if cp_name else os.path.join(self.checkpoint_path, 'checkpoint_{}ep.pt')
		self.cuda_mode = cuda
		self.pretrain = pretrain
		self.model = model
		self.optimizer = optimizer
		self.train_loader = train_loader
		self.valid_loader = valid_loader
		self.total_iters = 0
		self.cur_epoch = 0
		self.harvester = AllTripletSelector()
		self.verbose = verbose
		self.save_cp = save_cp
		self.device = next(self.model.parameters()).device
		self.history = {'train_loss': [], 'train_loss_batch': [], 'ce_loss': [], 'ce_loss_batch': [], 'bin_loss': [], 'bin_loss_batch': []}
		self.disc_label_smoothing = label_smoothing*0.5

		if label_smoothing>0.0:
			self.ce_criterion = LabelSmoothingLoss(label_smoothing, lbl_set_size=10)
		else:
			self.ce_criterion = torch.nn.CrossEntropyLoss()

		if self.valid_loader is not None:
			self.history['e2e_eer'] = []
			self.history['cos_eer'] = []
			self.history['ErrorRate'] = []
			self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, patience=patience, verbose=True if self.verbose>0 else False, threshold=1e-4, min_lr=1e-7)
		else:
			self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[20, 100, 200, 300, 400], gamma=0.1)

		if checkpoint_epoch is not None:
			self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))

	def train(self, n_epochs=1, save_every=1):

		while (self.cur_epoch < n_epochs):

			np.random.seed()

			if self.verbose>0:
				print(' ')
				print('Epoch {}/{}'.format(self.cur_epoch+1, n_epochs))
				train_iter = tqdm(enumerate(self.train_loader))
			else:
				train_iter = enumerate(self.train_loader)

			if self.pretrain:

				ce_epoch=0.0
				for t, batch in train_iter:
					ce = self.pretrain_step(batch)
					self.history['train_loss_batch'].append(ce)
					ce_epoch+=ce
					self.total_iters += 1

				self.history['train_loss'].append(ce_epoch/(t+1))

				if self.verbose>0:
					print('Train loss: {:0.4f}'.format(self.history['train_loss'][-1]))

			else:

				train_loss_epoch=0.0
				ce_loss_epoch=0.0
				bin_loss_epoch=0.0
				for t, batch in train_iter:
					train_loss, ce_loss, bin_loss = self.train_step(batch)
					self.history['train_loss_batch'].append(train_loss)
					self.history['ce_loss_batch'].append(ce_loss)
					self.history['bin_loss_batch'].append(bin_loss)
					train_loss_epoch+=train_loss
					ce_loss_epoch+=ce_loss
					bin_loss_epoch+=bin_loss
					self.total_iters += 1

				self.history['train_loss'].append(train_loss_epoch/(t+1))
				self.history['ce_loss'].append(ce_loss_epoch/(t+1))
				self.history['bin_loss'].append(bin_loss_epoch/(t+1))

				if self.verbose>0:
					print(' ')
					print('Total train loss: {:0.4f}'.format(self.history['train_loss'][-1]))
					print('CE loss: {:0.4f}'.format(self.history['ce_loss'][-1]))
					print('Binary classification loss: {:0.4f}'.format(self.history['bin_loss'][-1]))
					print(' ')

			if self.valid_loader is not None:

				tot_correct, tot_ = 0, 0
				e2e_scores, cos_scores, labels = None, None, None

				for t, batch in enumerate(self.valid_loader):
					correct, total, e2e_scores_batch, cos_scores_batch, labels_batch = self.valid(batch)

					try:
						e2e_scores = np.concatenate([e2e_scores, e2e_scores_batch], 0)
						cos_scores = np.concatenate([cos_scores, cos_scores_batch], 0)
						labels = np.concatenate([labels, labels_batch], 0)
					except:
						e2e_scores, cos_scores, labels = e2e_scores_batch, cos_scores_batch, labels_batch

					tot_correct += correct
					tot_ += total

				self.history['e2e_eer'].append(compute_eer(labels, e2e_scores))
				self.history['cos_eer'].append(compute_eer(labels, cos_scores))
				self.history['ErrorRate'].append(1.-float(tot_correct)/tot_)

				if self.verbose>0:
					print(' ')
					print('Current e2e EER, best e2e EER, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['e2e_eer'][-1], np.min(self.history['e2e_eer']), 1+np.argmin(self.history['e2e_eer'])))
					print('Current cos EER, best cos EER, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['cos_eer'][-1], np.min(self.history['cos_eer']), 1+np.argmin(self.history['cos_eer'])))
					print('Current Error rate, best Error rate, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['ErrorRate'][-1], np.min(self.history['ErrorRate']), 1+np.argmin(self.history['ErrorRate'])))

				self.scheduler.step(np.min([self.history['e2e_eer'][-1], self.history['cos_eer'][-1]]))

			else:
				self.scheduler.step()

			if self.verbose>0:
				print('Current LR: {}'.format(self.optimizer.param_groups[0]['lr']))

			self.cur_epoch += 1

			if self.valid_loader is not None and self.save_cp and (self.cur_epoch % save_every == 0 or self.history['e2e_eer'][-1] < np.min([np.inf]+self.history['e2e_eer'][:-1]) or self.history['cos_eer'][-1] < np.min([np.inf]+self.history['cos_eer'][:-1])):
				self.checkpointing()
			elif self.save_cp and self.cur_epoch % save_every == 0:
				self.checkpointing()

		if self.verbose>0:
			print('Training done!')

		if self.valid_loader is not None:
			if self.verbose>0:
				print('Best e2e eer and corresponding epoch: {:0.4f}, {}'.format(np.min(self.history['e2e_eer']), 1+np.argmin(self.history['e2e_eer'])))
				print('Best cos eer and corresponding epoch: {:0.4f}, {}'.format(np.min(self.history['cos_eer']), 1+np.argmin(self.history['cos_eer'])))

			return [np.min(self.history['e2e_eer']), np.min(self.history['cos_eer']), np.min(self.history['ErrorRate'])]
		else:
			return [np.min(self.history['train_loss'])]

	def train_step(self, batch):

		self.model.train()
		self.optimizer.zero_grad()

		x, y = batch

		x = x.to(self.device)
		y = y.to(self.device)

		embeddings = self.model.forward(x)

		embeddings_norm = F.normalize(embeddings, p=2, dim=1)

		ce_loss = self.ce_criterion(self.model.out_proj(embeddings_norm, y), y)

		# Get all triplets now for bin classifier
		triplets_idx = self.harvester.get_triplets(embeddings_norm.detach(), y)
		triplets_idx = triplets_idx.to(self.device)

		emb_a = torch.index_select(embeddings, 0, triplets_idx[:, 0])
		emb_p = torch.index_select(embeddings, 0, triplets_idx[:, 1])
		emb_n = torch.index_select(embeddings, 0, triplets_idx[:, 2])

		emb_ap = torch.cat([emb_a, emb_p],1)
		emb_an = torch.cat([emb_a, emb_n],1)
		emb_ = torch.cat([emb_ap, emb_an],0)

		y_ = torch.cat([torch.rand(emb_ap.size(0))*self.disc_label_smoothing+(1.0-self.disc_label_smoothing), torch.rand(emb_an.size(0))*self.disc_label_smoothing],0) if isinstance(self.ce_criterion, LabelSmoothingLoss) else torch.cat([torch.ones(emb_ap.size(0)), torch.zeros(emb_an.size(0))],0)
		y_ = y_.to(self.device)

		pred_bin = self.model.forward_bin(emb_).squeeze()

		loss_bin = torch.nn.BCELoss()(pred_bin, y_)

		loss = ce_loss + loss_bin
		loss.backward()
		self.optimizer.step()

		return loss.item(), ce_loss.item(), loss_bin.item()


	def pretrain_step(self, batch):

		self.model.train()
		self.optimizer.zero_grad()

		x, y = batch

		x, y = x.to(self.device), y.to(self.device).squeeze()

		embeddings = self.model.forward(x)
		embeddings_norm = F.normalize(embeddings, p=2, dim=1)

		loss = F.cross_entropy(self.model.out_proj(embeddings_norm, y), y)

		loss.backward()
		self.optimizer.step()
		return loss.item()


	def valid(self, batch):

		self.model.eval()

		with torch.no_grad():

			x, y = batch

			x = x.to(self.device)
			y = y.to(self.device)

			embeddings = self.model.forward(x)

			embeddings_norm = F.normalize(embeddings, p=2, dim=1)

			out = self.model.out_proj(embeddings_norm, y)

			pred = F.softmax(out, dim=1).max(1)[1].long()
			correct = pred.squeeze().eq(y.squeeze()).detach().sum().item()

			# Get all triplets now for bin classifier
			triplets_idx = self.harvester.get_triplets(embeddings_norm.detach(), y)
			triplets_idx = triplets_idx.to(self.device)

			emb_a = torch.index_select(embeddings, 0, triplets_idx[:, 0])
			emb_p = torch.index_select(embeddings, 0, triplets_idx[:, 1])
			emb_n = torch.index_select(embeddings, 0, triplets_idx[:, 2])

			emb_ap = torch.cat([emb_a, emb_p],1)
			emb_an = torch.cat([emb_a, emb_n],1)

			e2e_scores_p = self.model.forward_bin(emb_ap).squeeze()
			e2e_scores_n = self.model.forward_bin(emb_an).squeeze()
			cos_scores_p = torch.nn.functional.cosine_similarity(emb_a, emb_p)
			cos_scores_n = torch.nn.functional.cosine_similarity(emb_a, emb_n)

		return correct, x.size(0), np.concatenate([e2e_scores_p.detach().cpu().numpy(), e2e_scores_n.detach().cpu().numpy()], 0), np.concatenate([cos_scores_p.detach().cpu().numpy(), cos_scores_n.detach().cpu().numpy()], 0), np.concatenate([np.ones(e2e_scores_p.size(0)), np.zeros(e2e_scores_n.size(0))], 0)

	def checkpointing(self):

		# Checkpointing
		if self.verbose>0:
			print('Checkpointing...')
		ckpt = {'model_state': self.model.state_dict(),
		'dropout_prob': self.model.dropout_prob,
		'n_hidden': self.model.n_hidden,
		'hidden_size': self.model.hidden_size,
		'sm_type': self.model.sm_type,
		'optimizer_state': self.optimizer.state_dict(),
		'scheduler_state': self.scheduler.state_dict(),
		'history': self.history,
		'total_iters': self.total_iters,
		'cur_epoch': self.cur_epoch}
		try:
			torch.save(ckpt, self.save_epoch_fmt.format(self.cur_epoch))
		except:
			torch.save(ckpt, self.save_epoch_fmt)

	def load_checkpoint(self, ckpt):

		if os.path.isfile(ckpt):

			ckpt = torch.load(ckpt, map_location = lambda storage, loc: storage)
			# Load model state
			self.model.load_state_dict(ckpt['model_state'])
			# Load optimizer state
			self.optimizer.load_state_dict(ckpt['optimizer_state'])
			# Load scheduler state
			self.scheduler.load_state_dict(ckpt['scheduler_state'])
			# Load history
			self.history = ckpt['history']
			self.total_iters = ckpt['total_iters']
			self.cur_epoch = ckpt['cur_epoch']
			if self.cuda_mode:
				self.model = self.model.cuda(self.device)

		else:
			print('No checkpoint found at: {}'.format(ckpt))

	def print_grad_norms(self):
		norm = 0.0
		for params in list(self.model.parameters()):
			norm+=params.grad.norm(2).item()
		print('Sum of grads norms: {}'.format(norm))
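For context, this is roughly how the class above would be wired up. The toy model only illustrates the interface the loop expects (forward, out_proj(embeddings, y), forward_bin, plus the attributes read by checkpointing); every name and hyperparameter below is an illustrative assumption, and actually running it still requires the repository's AllTripletSelector, LabelSmoothingLoss, and compute_eer.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class ToyModel(nn.Module):
	def __init__(self, in_dim=40, emb_dim=64, n_classes=10):
		super().__init__()
		self.emb = nn.Linear(in_dim, emb_dim)
		self.cls = nn.Linear(emb_dim, n_classes)
		self.bin = nn.Linear(2 * emb_dim, 1)
		# Attributes stored by checkpointing():
		self.dropout_prob, self.n_hidden, self.hidden_size, self.sm_type = 0.0, 1, emb_dim, 'softmax'

	def forward(self, x):
		return self.emb(x)

	def out_proj(self, embeddings, y=None):
		return self.cls(embeddings)

	def forward_bin(self, pairs):
		return torch.sigmoid(self.bin(pairs))

x, y = torch.randn(256, 40), torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
trainer = TrainLoop(model, optimizer, loader, loader, patience=5, label_smoothing=0.1,
	verbose=1, save_cp=False, cuda=torch.cuda.is_available())
best = trainer.train(n_epochs=2, save_every=1)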
Example #6
class TrainLoop(object):
    def __init__(self,
                 model,
                 optimizer,
                 train_loader,
                 valid_loader,
                 margin,
                 lambda_,
                 patience,
                 verbose=-1,
                 cp_name=None,
                 save_cp=False,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 swap=False,
                 cuda=True):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt = os.path.join(
            self.checkpoint_path, cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'checkpoint_{}ep.pt')
        self.cuda_mode = cuda
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.history = {
            'train_loss': [],
            'train_loss_batch': [],
            'triplet_loss': [],
            'triplet_loss_batch': [],
            'ce_loss': [],
            'ce_loss_batch': [],
            'ErrorRate': [],
            'EER': []
        }
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            factor=0.5,
            patience=patience,
            verbose=True if verbose > 0 else False,
            threshold=1e-4,
            min_lr=1e-8)
        self.total_iters = 0
        self.cur_epoch = 0
        self.lambda_ = lambda_
        self.swap = swap
        self.margin = margin
        self.harvester = HardestNegativeTripletSelector(margin=0.1,
                                                        cpu=not self.cuda_mode)
        self.harvester_val = AllTripletSelector()
        self.verbose = verbose
        self.save_cp = save_cp
        self.device = next(self.model.parameters()).device

        if checkpoint_epoch is not None:
            self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))

    def train(self, n_epochs=1, save_every=1):

        while self.cur_epoch < n_epochs:

            np.random.seed()

            if self.verbose > 0:
                print(' ')
                print('Epoch {}/{}'.format(self.cur_epoch + 1, n_epochs))
                train_iter = tqdm(enumerate(self.train_loader))
            else:
                train_iter = enumerate(self.train_loader)

            ce = 0.0
            triplet_loss = 0.0
            train_loss = 0.0

            # Train step
            for t, batch in train_iter:
                ce_batch, triplet_loss_batch = self.train_step(batch)
                ce += ce_batch
                triplet_loss += triplet_loss_batch
                train_loss += ce_batch + triplet_loss_batch
                self.history['train_loss_batch'].append(ce_batch +
                                                        triplet_loss_batch)
                self.history['triplet_loss_batch'].append(triplet_loss_batch)
                self.history['ce_loss_batch'].append(ce_batch)
                self.total_iters += 1

            self.history['train_loss'].append(train_loss / (t + 1))
            self.history['triplet_loss'].append(triplet_loss / (t + 1))
            self.history['ce_loss'].append(ce / (t + 1))

            if self.verbose > 0:
                print(' ')
                print(
                    'Total train loss, Triplet loss, and Cross-entropy: {:0.4f}, {:0.4f}, {:0.4f}'
                    .format(self.history['train_loss'][-1],
                            self.history['triplet_loss'][-1],
                            self.history['ce_loss'][-1]))

            # Validation

            tot_correct = 0
            tot_ = 0
            scores, labels = None, None

            for t, batch in enumerate(self.valid_loader):

                correct, total, scores_batch, labels_batch = self.valid(batch)

                try:
                    scores = np.concatenate([scores, scores_batch], 0)
                    labels = np.concatenate([labels, labels_batch], 0)
                except:
                    scores, labels = scores_batch, labels_batch

                tot_correct += correct
                tot_ += total

            self.history['EER'].append(compute_eer(labels, scores))
            self.history['ErrorRate'].append(1. - float(tot_correct) / tot_)

            if self.verbose > 0:
                print(' ')
                print(
                    'Current, best validation error rate, and epoch: {:0.4f}, {:0.4f}, {}'
                    .format(self.history['ErrorRate'][-1],
                            np.min(self.history['ErrorRate']),
                            1 + np.argmin(self.history['ErrorRate'])))

                print(' ')
                print(
                    'Current, best validation EER, and epoch: {:0.4f}, {:0.4f}, {}'
                    .format(self.history['EER'][-1],
                            np.min(self.history['EER']),
                            1 + np.argmin(self.history['EER'])))

            self.scheduler.step(self.history['ErrorRate'][-1])

            if self.verbose > 0:
                print(' ')
                print('Current LR: {}'.format(
                    self.optimizer.param_groups[0]['lr']))

            if self.save_cp and (
                    self.cur_epoch % save_every == 0 or
                (self.history['ErrorRate'][-1] <
                 np.min([np.inf] + self.history['ErrorRate'][:-1])) or
                (self.history['EER'][-1] <
                 np.min([np.inf] + self.history['EER'][:-1]))):
                self.checkpointing()

            self.cur_epoch += 1

        if self.verbose > 0:
            print('Training done!')

            if self.valid_loader is not None:
                print('Best error rate and corresponding epoch: {:0.4f}, {}'.
                      format(np.min(self.history['ErrorRate']),
                             1 + np.argmin(self.history['ErrorRate'])))
                print('Best EER and corresponding epoch: {:0.4f}, {}'.format(
                    np.min(self.history['EER']),
                    1 + np.argmin(self.history['EER'])))

        return np.min(self.history['ErrorRate'])

    def train_step(self, batch):

        self.model.train()

        self.optimizer.zero_grad()

        x, y = batch

        if self.cuda_mode:
            x = x.to(self.device)
            y = y.to(self.device)

        #x = x.view(x.size(0)*x.size(1), x.size(2), x.size(3), x.size(4))
        #y = y.view(y.size(0)*y.size(1))

        embeddings = self.model.forward(x)

        embeddings_norm = F.normalize(embeddings, p=2, dim=1)

        loss_class = torch.nn.CrossEntropyLoss()(self.model.out_proj(
            embeddings_norm, y), y)

        triplets_idx, entropy_indices = self.harvester.get_triplets(
            embeddings_norm.detach(), y)

        if self.cuda_mode:
            triplets_idx = triplets_idx.to(self.device)

        emb_a = torch.index_select(embeddings_norm, 0, triplets_idx[:, 0])
        emb_p = torch.index_select(embeddings_norm, 0, triplets_idx[:, 1])
        emb_n = torch.index_select(embeddings_norm, 0, triplets_idx[:, 2])

        loss_metric = self.triplet_loss(emb_a, emb_p, emb_n)

        loss = loss_class + loss_metric

        entropy_regularizer = torch.nn.functional.pairwise_distance(
            embeddings_norm, embeddings_norm[entropy_indices, :]).mean()
        loss -= entropy_regularizer * self.lambda_

        loss.backward()

        self.optimizer.step()

        return loss_class.item(), loss_metric.item()

    def valid(self, batch):

        self.model.eval()

        x, y = batch

        if self.cuda_mode:
            x = x.to(self.device)
            y = y.to(self.device)

        with torch.no_grad():

            embeddings = self.model.forward(x)
            embeddings_norm = F.normalize(embeddings, p=2, dim=1)
            out = self.model.out_proj(embeddings_norm, y)

            pred = F.softmax(out, dim=1).max(1)[1].long()
            correct = pred.squeeze().eq(y.squeeze()).detach().sum().item()

            triplets_idx = self.harvester_val.get_triplets(embeddings, y)

            embeddings = embeddings.cpu()

            emb_a = torch.index_select(embeddings, 0, triplets_idx[:, 0])
            emb_p = torch.index_select(embeddings, 0, triplets_idx[:, 1])
            emb_n = torch.index_select(embeddings, 0, triplets_idx[:, 2])

            scores_p = F.cosine_similarity(emb_a, emb_p)
            scores_n = F.cosine_similarity(emb_a, emb_n)

        return correct, x.size(0), np.concatenate(
            [scores_p.detach().cpu().numpy(),
             scores_n.detach().cpu().numpy()], 0), np.concatenate(
                 [np.ones(scores_p.size(0)),
                  np.zeros(scores_n.size(0))], 0)

    def triplet_loss(self, emba, embp, embn, reduce_=True):

        loss_ = torch.nn.TripletMarginLoss(
            margin=self.margin,
            p=2.0,
            eps=1e-06,
            swap=self.swap,
            reduction='mean' if reduce_ else 'none')(emba, embp, embn)

        return loss_

    def checkpointing(self):

        # Checkpointing
        if self.verbose > 0:
            print(' ')
            print('Checkpointing...')
        ckpt = {
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'history': self.history,
            'total_iters': self.total_iters,
            'cur_epoch': self.cur_epoch
        }

        try:
            torch.save(ckpt, self.save_epoch_fmt.format(self.cur_epoch))
        except:
            torch.save(ckpt, self.save_epoch_fmt)

    def load_checkpoint(self, ckpt):

        if os.path.isfile(ckpt):

            ckpt = torch.load(ckpt)
            # Load model state
            self.model.load_state_dict(ckpt['model_state'])
            # Load optimizer state
            self.optimizer.load_state_dict(ckpt['optimizer_state'])
            # Load scheduler state
            self.scheduler.load_state_dict(ckpt['scheduler_state'])
            # Load history
            self.history = ckpt['history']
            self.total_iters = ckpt['total_iters']
            self.cur_epoch = ckpt['cur_epoch']

        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self):
        norm = 0.0
        for params in list(self.model.parameters()):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))

    def check_nans(self):
        for params in list(self.model.parameters()):
            if np.any(np.isnan(params.data.cpu().numpy())):
                print('params NANs!!!!!')
            if np.any(np.isnan(params.grad.data.cpu().numpy())):
                print('grads NANs!!!!!!')

    def initialize_params(self):
        for layer in self.model.modules():
            if isinstance(layer, torch.nn.Conv2d):
                init.kaiming_normal_(layer.weight.data)
            elif isinstance(layer, torch.nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
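train_step above combines cross-entropy with torch.nn.TripletMarginLoss on (anchor, positive, negative) embeddings and then subtracts a lambda_-weighted regularizer, the mean pairwise distance between each embedding and a peer chosen by the harvester. A standalone sketch of that objective on random data follows; the margin, lambda_, and the random peer indices are illustrative (the snippet gets entropy_indices from HardestNegativeTripletSelector):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
embeddings_norm = F.normalize(torch.randn(15, 64), p=2, dim=1)
anchor, positive, negative = embeddings_norm[:5], embeddings_norm[5:10], embeddings_norm[10:15]

margin, lambda_ = 0.3, 0.1
loss_metric = torch.nn.TripletMarginLoss(margin=margin, p=2.0, eps=1e-06, swap=False)(anchor, positive, negative)

# Regularizer: mean distance to a (here random) peer embedding; subtracting it pushes embeddings apart.
peers = torch.randperm(embeddings_norm.size(0))
entropy_regularizer = F.pairwise_distance(embeddings_norm, embeddings_norm[peers, :]).mean()

loss = loss_metric - lambda_ * entropy_regularizer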
Example #7
class TrainLoop(object):
    def __init__(self,
                 model,
                 optimizer,
                 train_loader,
                 valid_loader,
                 slack,
                 train_mode,
                 patience,
                 verbose=-1,
                 cp_name=None,
                 save_cp=False,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 cuda=True):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt = os.path.join(
            self.checkpoint_path, cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'checkpoint_{}ep.pt')
        self.cuda_mode = cuda
        self.model = model
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.history = {
            'train_loss': [],
            'train_loss_batch': [],
            'ErrorRate': [],
            'EER': []
        }
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            factor=0.5,
            patience=patience,
            verbose=True if verbose > 0 else False,
            threshold=1e-4,
            min_lr=1e-8)
        self.total_iters = 0
        self.cur_epoch = 0
        self.slack = slack
        self.train_mode = train_mode
        self.harvester_val = AllTripletSelector()
        self.verbose = verbose
        self.save_cp = save_cp
        self.device = next(self.model.parameters()).device

        if checkpoint_epoch is not None:
            self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))

    def train(self, n_epochs=1, save_every=1):

        while self.cur_epoch < n_epochs:

            np.random.seed()

            if self.verbose > 0:
                print(' ')
                print('Epoch {}/{}'.format(self.cur_epoch + 1, n_epochs))
                train_iter = tqdm(enumerate(self.train_loader))
            else:
                train_iter = enumerate(self.train_loader)

            train_loss = 0.0

            # Train step
            for t, batch in train_iter:
                train_loss_batch = self.train_step(batch)
                self.history['train_loss_batch'].append(train_loss_batch)
                train_loss += train_loss_batch
                self.total_iters += 1

            self.history['train_loss'].append(train_loss / (t + 1))

            if self.verbose > 0:
                print(' ')
                print('Total train loss: {:0.4f}'.format(
                    self.history['train_loss'][-1]))

            # Validation

            tot_correct = 0
            tot_ = 0
            scores, labels = None, None

            for t, batch in enumerate(self.valid_loader):

                correct, total, scores_batch, labels_batch = self.valid(batch)

                try:
                    scores = np.concatenate([scores, scores_batch], 0)
                    labels = np.concatenate([labels, labels_batch], 0)
                except:
                    scores, labels = scores_batch, labels_batch

                tot_correct += correct
                tot_ += total

            self.history['EER'].append(compute_eer(labels, scores))
            self.history['ErrorRate'].append(1. - float(tot_correct) / tot_)

            if self.verbose > 0:
                print(' ')
                print(
                    'Current, best validation error rate, and epoch: {:0.4f}, {:0.4f}, {}'
                    .format(self.history['ErrorRate'][-1],
                            np.min(self.history['ErrorRate']),
                            1 + np.argmin(self.history['ErrorRate'])))

                print(' ')
                print(
                    'Current, best validation EER, and epoch: {:0.4f}, {:0.4f}, {}'
                    .format(self.history['EER'][-1],
                            np.min(self.history['EER']),
                            1 + np.argmin(self.history['EER'])))

            self.scheduler.step(self.history['ErrorRate'][-1])

            if self.verbose > 0:
                print(' ')
                print('Current LR: {}'.format(
                    self.optimizer.param_groups[0]['lr']))

            if self.save_cp and (
                    self.cur_epoch % save_every == 0 or
                (self.history['ErrorRate'][-1] <
                 np.min([np.inf] + self.history['ErrorRate'][:-1])) or
                (self.history['EER'][-1] <
                 np.min([np.inf] + self.history['EER'][:-1]))):
                self.checkpointing()

            self.cur_epoch += 1

        if self.verbose > 0:
            print('Training done!')

            if self.valid_loader is not None:
                print('Best error rate and corresponding epoch: {:0.4f}, {}'.
                      format(np.min(self.history['ErrorRate']),
                             1 + np.argmin(self.history['ErrorRate'])))
                print('Best EER and corresponding epoch: {:0.4f}, {}'.format(
                    np.min(self.history['EER']),
                    1 + np.argmin(self.history['EER'])))

        return np.min(self.history['ErrorRate'])

    def train_step(self, batch):

        self.model.train()

        self.optimizer.zero_grad()

        x, y = batch

        if self.cuda_mode:
            x = x.to(self.device)
            y = y.to(self.device)

        embeddings = self.model.forward(x)

        loss = torch.nn.CrossEntropyLoss(
            reduction='none' if self.train_mode == 'hyper' else 'mean')(
                self.model.out_proj(embeddings), y)

        if self.train_mode == 'hyper':
            eta = self.slack * loss.detach().max().item() + 1e-6
            loss = -torch.log(eta - loss).sum()

        loss.backward()

        self.optimizer.step()

        return loss.item()

    def valid(self, batch):

        self.model.eval()

        x, y = batch

        if self.cuda_mode:
            x = x.to(self.device)
            y = y.to(self.device)

        with torch.no_grad():

            embeddings = self.model.forward(x)
            out = self.model.out_proj(embeddings)

            pred = F.softmax(out, dim=1).max(1)[1].long()
            correct = pred.squeeze().eq(y.squeeze()).detach().sum().item()

            triplets_idx = self.harvester_val.get_triplets(embeddings, y)

            embeddings = embeddings.cpu()

            emb_a = torch.index_select(embeddings, 0, triplets_idx[:, 0])
            emb_p = torch.index_select(embeddings, 0, triplets_idx[:, 1])
            emb_n = torch.index_select(embeddings, 0, triplets_idx[:, 2])

            scores_p = F.cosine_similarity(emb_a, emb_p)
            scores_n = F.cosine_similarity(emb_a, emb_n)

        return correct, x.size(0), np.concatenate(
            [scores_p.detach().cpu().numpy(),
             scores_n.detach().cpu().numpy()], 0), np.concatenate(
                 [np.ones(scores_p.size(0)),
                  np.zeros(scores_n.size(0))], 0)

    def checkpointing(self):

        # Checkpointing
        if self.verbose > 0:
            print(' ')
            print('Checkpointing...')
        ckpt = {
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'history': self.history,
            'total_iters': self.total_iters,
            'cur_epoch': self.cur_epoch
        }

        try:
            torch.save(ckpt, self.save_epoch_fmt.format(self.cur_epoch))
        except:
            torch.save(ckpt, self.save_epoch_fmt)

    def load_checkpoint(self, ckpt):

        if os.path.isfile(ckpt):

            ckpt = torch.load(ckpt)
            # Load model state
            self.model.load_state_dict(ckpt['model_state'])
            # Load optimizer state
            self.optimizer.load_state_dict(ckpt['optimizer_state'])
            # Load scheduler state
            self.scheduler.load_state_dict(ckpt['scheduler_state'])
            # Load history
            self.history = ckpt['history']
            self.total_iters = ckpt['total_iters']
            self.cur_epoch = ckpt['cur_epoch']

        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self):
        norm = 0.0
        for params in list(self.model.parameters()):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))

    def check_nans(self):
        for params in list(self.model.parameters()):
            if np.any(np.isnan(params.data.cpu().numpy())):
                print('params NANs!!!!!')
            if np.any(np.isnan(params.grad.data.cpu().numpy())):
                print('grads NANs!!!!!!')
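In train_step above, the 'hyper' mode keeps the per-sample cross-entropy losses (reduction='none'), sets a barrier level eta = slack * max(loss) + 1e-6, and minimizes -sum(log(eta - loss)), which penalizes the worst samples far more than the average. A self-contained sketch of that transformation follows; the logits, labels, and slack value are made up, and slack must exceed 1 so the barrier sits above the largest loss:

import torch

torch.manual_seed(0)
logits = torch.randn(8, 5, requires_grad=True)
labels = torch.randint(0, 5, (8,))

per_sample = torch.nn.CrossEntropyLoss(reduction='none')(logits, labels)
slack = 1.5
eta = slack * per_sample.detach().max().item() + 1e-6   # barrier just above the worst sample
loss = -torch.log(eta - per_sample).sum()               # log-barrier: the largest losses dominate the gradient
loss.backward()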
Example #8
    def __init__(self,
                 model,
                 optimizer,
                 train_loader,
                 valid_loader,
                 max_gnorm,
                 label_smoothing,
                 verbose=-1,
                 cp_name=None,
                 save_cp=False,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 pretrain=False,
                 ablation=False,
                 cuda=True,
                 logger=None):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt = os.path.join(
            self.checkpoint_path, cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'checkpoint_{}ep.pt')
        self.cuda_mode = cuda
        self.pretrain = pretrain
        self.ablation = ablation
        self.model = model
        self.optimizer = optimizer
        self.max_gnorm = max_gnorm
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.total_iters = 0
        self.cur_epoch = 0
        self.harvester = AllTripletSelector()
        self.verbose = verbose
        self.save_cp = save_cp
        self.device = next(self.model.parameters()).device
        self.logger = logger
        self.history = {
            'train_loss': [],
            'train_loss_batch': [],
            'ce_loss': [],
            'ce_loss_batch': [],
            'bin_loss': [],
            'bin_loss_batch': []
        }
        self.disc_label_smoothing = label_smoothing
        self.best_e2e_eer, self.best_cos_eer = np.inf, np.inf

        if label_smoothing > 0.0:
            self.ce_criterion = LabelSmoothingLoss(
                label_smoothing, lbl_set_size=self.model.n_classes)
        else:
            self.ce_criterion = torch.nn.CrossEntropyLoss()

        if self.valid_loader is not None:
            self.history['e2e_eer'] = []
            self.history['cos_eer'] = []

        if checkpoint_epoch is not None:
            self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))
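LabelSmoothingLoss here is a repository class whose implementation is not shown; the technique it names spreads 1 - smoothing probability mass on the target class and smoothing/(K - 1) on the others before taking cross-entropy. A generic sketch of such a criterion, as an illustration of the technique rather than the project's class (lbl_set_size mirrors the constructor argument above):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SmoothedCrossEntropy(nn.Module):
    def __init__(self, smoothing, lbl_set_size):
        super().__init__()
        self.smoothing = smoothing
        self.n_classes = lbl_set_size

    def forward(self, logits, target):
        log_probs = F.log_softmax(logits, dim=1)
        with torch.no_grad():
            # Smoothed target distribution: 1 - smoothing on the true class, the rest spread evenly.
            true_dist = torch.full_like(log_probs, self.smoothing / (self.n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - self.smoothing)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=1))

criterion = SmoothedCrossEntropy(0.1, lbl_set_size=10)
loss = criterion(torch.randn(4, 10), torch.randint(0, 10, (4,)))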
Example #9
class TrainLoop(object):
    def __init__(self,
                 model,
                 optimizer,
                 train_loader,
                 valid_loader,
                 max_gnorm,
                 label_smoothing,
                 verbose=-1,
                 cp_name=None,
                 save_cp=False,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 pretrain=False,
                 ablation=False,
                 cuda=True,
                 logger=None):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        self.save_epoch_fmt = os.path.join(
            self.checkpoint_path, cp_name) if cp_name else os.path.join(
                self.checkpoint_path, 'checkpoint_{}ep.pt')
        self.cuda_mode = cuda
        self.pretrain = pretrain
        self.ablation = ablation
        self.model = model
        self.optimizer = optimizer
        self.max_gnorm = max_gnorm
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.total_iters = 0
        self.cur_epoch = 0
        self.harvester = AllTripletSelector()
        self.verbose = verbose
        self.save_cp = save_cp
        self.device = next(self.model.parameters()).device
        self.logger = logger
        self.history = {
            'train_loss': [],
            'train_loss_batch': [],
            'ce_loss': [],
            'ce_loss_batch': [],
            'bin_loss': [],
            'bin_loss_batch': []
        }
        self.disc_label_smoothing = label_smoothing
        self.best_e2e_eer, self.best_cos_eer = np.inf, np.inf

        if label_smoothing > 0.0:
            self.ce_criterion = LabelSmoothingLoss(
                label_smoothing, lbl_set_size=self.model.n_classes)
        else:
            self.ce_criterion = torch.nn.CrossEntropyLoss()

        if self.valid_loader is not None:
            self.history['e2e_eer'] = []
            self.history['cos_eer'] = []

        if checkpoint_epoch is not None:
            self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))

    def train(self, n_epochs=1, save_every=1, eval_every=1000):

        while (self.cur_epoch < n_epochs):

            self.cur_epoch += 1

            np.random.seed()
            if isinstance(self.train_loader.dataset, Loader):
                self.train_loader.dataset.update_lists()

            if self.verbose > 0:
                print(' ')
                print('Epoch {}/{}'.format(self.cur_epoch, n_epochs))
                train_iter = tqdm(enumerate(self.train_loader))
            else:
                train_iter = enumerate(self.train_loader)

            if self.pretrain:
                self.save_epoch_cp = False
                ce_epoch = 0.0
                for t, batch in train_iter:
                    ce = self.pretrain_step(batch)
                    self.history['train_loss_batch'].append(ce)
                    ce_epoch += ce
                    if self.logger:
                        self.logger.add_scalar('Train/Cross entropy', ce,
                                               self.total_iters)
                        self.logger.add_scalar(
                            'Info/LR',
                            self.optimizer.optimizer.param_groups[0]['lr'],
                            self.total_iters)

                    self.total_iters += 1

                self.history['train_loss'].append(ce_epoch / (t + 1))

                if self.verbose > 0:
                    print('Train loss: {:0.4f}'.format(
                        self.history['train_loss'][-1]))

            else:
                self.save_epoch_cp = False
                train_loss_epoch = 0.0
                ce_loss_epoch = 0.0
                bin_loss_epoch = 0.0
                for t, batch in train_iter:
                    train_loss, ce_loss, bin_loss = self.train_step(batch)
                    self.history['train_loss_batch'].append(train_loss)
                    self.history['ce_loss_batch'].append(ce_loss)
                    self.history['bin_loss_batch'].append(bin_loss)
                    train_loss_epoch += train_loss
                    ce_loss_epoch += ce_loss
                    bin_loss_epoch += bin_loss

                    self.total_iters += 1

                    if self.logger:
                        self.logger.add_scalar('Train/Total train Loss',
                                               train_loss, self.total_iters)
                        self.logger.add_scalar('Train/Binary class. Loss',
                                               bin_loss, self.total_iters)
                        self.logger.add_scalar('Train/Cross entropy', ce_loss,
                                               self.total_iters)
                        self.logger.add_scalar(
                            'Info/LR',
                            self.optimizer.optimizer.param_groups[0]['lr'],
                            self.total_iters)

                    if self.total_iters % eval_every == 0:
                        self.evaluate()
                        if self.save_cp and (
                                self.history['e2e_eer'][-1] <
                                np.min([np.inf] + self.history['e2e_eer'][:-1])
                                or self.history['cos_eer'][-1] < np.min(
                                    [np.inf] + self.history['cos_eer'][:-1])):
                            self.checkpointing()
                            self.save_epoch_cp = True

                self.history['train_loss'].append(train_loss_epoch / (t + 1))
                self.history['ce_loss'].append(ce_loss_epoch / (t + 1))
                self.history['bin_loss'].append(bin_loss_epoch / (t + 1))

                if self.verbose > 0:
                    print(' ')
                    print('Total train loss: {:0.4f}'.format(
                        self.history['train_loss'][-1]))
                    print('CE loss: {:0.4f}'.format(
                        self.history['ce_loss'][-1]))
                    print('Binary classification loss: {:0.4f}'.format(
                        self.history['bin_loss'][-1]))
                    print('Current LR: {}'.format(
                        self.optimizer.optimizer.param_groups[0]['lr']))
                    print(' ')

            if self.save_cp and self.cur_epoch % save_every == 0 and not self.save_epoch_cp:
                self.checkpointing()

        if self.verbose > 0:
            print('Training done!')

        if self.valid_loader is not None:
            if self.verbose > 0:
                print(
                    'Best e2e eer and corresponding epoch and iteration: {:0.4f}, {}, {}'
                    .format(np.min(self.history['e2e_eer']),
                            self.best_e2e_eer_epoch,
                            self.best_e2e_eer_iteration))
                print(
                    'Best cos eer and corresponding epoch and iteration: {:0.4f}, {}, {}'
                    .format(np.min(self.history['cos_eer']),
                            self.best_cos_eer_epoch,
                            self.best_cos_eer_iteration))

            return [
                np.min(self.history['e2e_eer']),
                np.min(self.history['cos_eer'])
            ]
        else:
            return [np.min(self.history['train_loss'])]

    def train_step(self, batch):

        self.model.train()
        self.optimizer.zero_grad()

        if isinstance(self.train_loader.dataset, Loader):
            x_1, x_2, x_3, x_4, x_5, y = batch
            x = torch.cat([x_1, x_2, x_3, x_4, x_5], dim=0)
            y = torch.cat(5 * [y], dim=0).squeeze().contiguous()
        else:
            x, y = batch

        x = x.to(self.device, non_blocking=True)
        y = y.to(self.device, non_blocking=True)

        if random.random() > 0.5:
            x += torch.randn_like(x) * random.choice([1e-4, 1e-5])

        embeddings, out = self.model.forward(x)
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)

        if not self.ablation:
            ce_loss = self.ce_criterion(
                self.model.out_proj(embeddings_norm, y), y)
        else:
            ce_loss = 0.0

        # Get all triplets now for bin classifier
        triplets_idx = self.harvester.get_triplets(embeddings.detach(), y)
        triplets_idx = triplets_idx.to(self.device, non_blocking=True)

        emb_a = torch.index_select(embeddings, 0, triplets_idx[:, 0])
        emb_p = torch.index_select(embeddings, 0, triplets_idx[:, 1])
        emb_n = torch.index_select(embeddings, 0, triplets_idx[:, 2])

        emb_ap = torch.cat([emb_a, emb_p], 1)
        emb_an = torch.cat([emb_a, emb_n], 1)
        emb_ = torch.cat([emb_ap, emb_an], 0)

        y_ = torch.cat([
            torch.rand(emb_ap.size(0)) * self.disc_label_smoothing +
            (1.0 - self.disc_label_smoothing),
            torch.rand(emb_an.size(0)) * self.disc_label_smoothing
        ], 0) if isinstance(
            self.ce_criterion, LabelSmoothingLoss) else torch.cat(
                [torch.ones(emb_ap.size(0)),
                 torch.zeros(emb_an.size(0))], 0)
        y_ = y_.to(self.device, non_blocking=True)

        pred_bin = self.model.forward_bin(emb_).squeeze()

        loss_bin = torch.nn.BCELoss()(pred_bin, y_)

        loss = ce_loss + loss_bin
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_gnorm)
        self.optimizer.step()

        if self.logger:
            self.logger.add_scalar('Info/Grad_norm', grad_norm,
                                   self.total_iters)

        return loss.item(), ce_loss.item(
        ) if not self.ablation else 0.0, loss_bin.item()

    def pretrain_step(self, batch):

        self.model.train()
        self.optimizer.zero_grad()

        x, y = batch

        x, y = x.to(self.device,
                    non_blocking=True), y.to(self.device,
                                             non_blocking=True).squeeze()

        embeddings, out = self.model.forward(x)
        embeddings_norm = F.normalize(embeddings, p=2, dim=1)

        loss = F.cross_entropy(self.model.out_proj(embeddings_norm, y), y)

        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.max_gnorm)
        self.optimizer.step()
        return loss.item()

    def valid(self, batch):

        self.model.eval()

        with torch.no_grad():

            if isinstance(self.valid_loader.dataset, Loader):
                x_1, x_2, x_3, x_4, x_5, y = batch
                x = torch.cat([x_1, x_2, x_3, x_4, x_5], dim=0)
                y = torch.cat(5 * [y], dim=0).squeeze().contiguous()
            else:
                x, y = batch

            x = x.to(self.device, non_blocking=True)
            y = y.to(self.device, non_blocking=True)

            embeddings, out = self.model.forward(x)

            # Get all triplets now for bin classifier
            triplets_idx = self.harvester.get_triplets(embeddings.detach(), y)
            triplets_idx = triplets_idx.to(self.device, non_blocking=True)

            emb_a = torch.index_select(embeddings, 0, triplets_idx[:, 0])
            emb_p = torch.index_select(embeddings, 0, triplets_idx[:, 1])
            emb_n = torch.index_select(embeddings, 0, triplets_idx[:, 2])

            emb_ap = torch.cat([emb_a, emb_p], 1)
            emb_an = torch.cat([emb_a, emb_n], 1)

            e2e_scores_p = self.model.forward_bin(emb_ap).squeeze()
            e2e_scores_n = self.model.forward_bin(emb_an).squeeze()
            cos_scores_p = torch.nn.functional.cosine_similarity(emb_a, emb_p)
            cos_scores_n = torch.nn.functional.cosine_similarity(emb_a, emb_n)

        return np.concatenate([
            e2e_scores_p.detach().cpu().numpy(),
            e2e_scores_n.detach().cpu().numpy()
        ], 0), np.concatenate([
            cos_scores_p.detach().cpu().numpy(),
            cos_scores_n.detach().cpu().numpy()
        ], 0), np.concatenate(
            [np.ones(e2e_scores_p.size(0)),
             np.zeros(e2e_scores_n.size(0))], 0)

    def evaluate(self):

        if self.verbose > 0:
            print('\nIteration - Epoch: {} - {}'.format(
                self.total_iters, self.cur_epoch))

        e2e_scores, cos_scores, labels = None, None, None

        for t, batch in enumerate(self.valid_loader):
            e2e_scores_batch, cos_scores_batch, labels_batch = self.valid(
                batch)

            try:
                e2e_scores = np.concatenate([e2e_scores, e2e_scores_batch], 0)
                cos_scores = np.concatenate([cos_scores, cos_scores_batch], 0)
                labels = np.concatenate([labels, labels_batch], 0)
            except:
                e2e_scores, cos_scores, labels = e2e_scores_batch, cos_scores_batch, labels_batch

        self.history['e2e_eer'].append(compute_eer(labels, e2e_scores))
        self.history['cos_eer'].append(compute_eer(labels, cos_scores))

        if self.history['e2e_eer'][-1] < self.best_e2e_eer:
            self.best_e2e_eer = self.history['e2e_eer'][-1]
            self.best_e2e_eer_epoch = self.cur_epoch
            self.best_e2e_eer_iteration = self.total_iters

        if self.history['cos_eer'][-1] < self.best_cos_eer:
            self.best_cos_eer = self.history['cos_eer'][-1]
            self.best_cos_eer_epoch = self.cur_epoch
            self.best_cos_eer_iteration = self.total_iters

        if self.logger:
            self.logger.add_scalar('Valid/E2E EER',
                                   self.history['e2e_eer'][-1],
                                   self.total_iters)
            self.logger.add_scalar('Valid/Best E2E EER',
                                   np.min(self.history['e2e_eer']),
                                   self.total_iters)
            self.logger.add_scalar('Valid/Cosine EER',
                                   self.history['cos_eer'][-1],
                                   self.total_iters)
            self.logger.add_scalar('Valid/Best Cosine EER',
                                   np.min(self.history['cos_eer']),
                                   self.total_iters)
            self.logger.add_pr_curve('E2E ROC',
                                     labels=labels,
                                     predictions=e2e_scores,
                                     global_step=self.total_iters)
            self.logger.add_pr_curve('Cosine ROC',
                                     labels=labels,
                                     predictions=cos_scores,
                                     global_step=self.total_iters)
            self.logger.add_histogram('Valid/COS_Scores',
                                      values=cos_scores,
                                      global_step=self.total_iters)
            self.logger.add_histogram('Valid/E2E_Scores',
                                      values=e2e_scores,
                                      global_step=self.total_iters)
            self.logger.add_histogram('Valid/Labels',
                                      values=labels,
                                      global_step=self.total_iters)

        if self.verbose > 0:
            print(' ')
            print(
                'Current e2e EER, best e2e EER, and epoch - iteration: {:0.4f}, {:0.4f}, {}, {}'
                .format(self.history['e2e_eer'][-1],
                        np.min(self.history['e2e_eer']),
                        self.best_e2e_eer_epoch, self.best_e2e_eer_iteration))
            print(
                'Current cos EER, best cos EER, and epoch - iteration: {:0.4f}, {:0.4f}, {}, {}'
                .format(self.history['cos_eer'][-1],
                        np.min(self.history['cos_eer']),
                        self.best_cos_eer_epoch, self.best_cos_eer_iteration))

    def checkpointing(self):

        # Checkpointing
        if self.verbose > 0:
            print('Checkpointing...')
        ckpt = {
            'model_state': self.model.state_dict(),
            'dropout_prob': self.model.dropout_prob,
            'n_hidden': self.model.n_hidden,
            'hidden_size': self.model.hidden_size,
            'sm_type': self.model.sm_type,
            'n_classes': self.model.n_classes,
            'emb_size': self.model.emb_size,
            'r_proj_size': self.model.r_proj_size,
            'optimizer_state': self.optimizer.state_dict(),
            'history': self.history,
            'total_iters': self.total_iters,
            'cur_epoch': self.cur_epoch
        }
        # str.format is a no-op when save_epoch_fmt has no '{}' placeholder (fixed cp_name)
        torch.save(ckpt, self.save_epoch_fmt.format(self.cur_epoch))

    def load_checkpoint(self, ckpt):

        if os.path.isfile(ckpt):

            ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage)
            # Load model state
            self.model.load_state_dict(ckpt['model_state'])
            # Load optimizer state
            self.optimizer.load_state_dict(ckpt['optimizer_state'])
            # Load history
            self.history = ckpt['history']
            self.total_iters = ckpt['total_iters']
            self.cur_epoch = ckpt['cur_epoch']
            if self.cuda_mode:
                self.model = self.model.cuda(self.device)

        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self):
        norm = 0.0
        for params in list(self.model.parameters()):
            if params.grad is not None:
                norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))
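The EER values recorded above come from a compute_eer helper that is not shown in this example. A minimal sketch of such a helper, assuming binary labels (1 = target trial) and higher-is-better similarity scores, could look like the following; the function name and the ROC-sweep approach are assumptions, not necessarily the repository's implementation.

import numpy as np

def compute_eer_sketch(labels, scores):
    # Approximate Equal Error Rate: sweep thresholds over the sorted scores and
    # return the point where the miss rate (FNR) and false-alarm rate (FPR) cross.
    order = np.argsort(scores)[::-1]
    labels = np.asarray(labels)[order]
    n_pos = labels.sum()
    n_neg = len(labels) - n_pos
    tpr = np.cumsum(labels) / n_pos
    fpr = np.cumsum(1 - labels) / n_neg
    fnr = 1.0 - tpr
    idx = np.argmin(np.abs(fnr - fpr))
    return float((fnr[idx] + fpr[idx]) / 2.0)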
Example No. 10
0
class TrainLoop(object):

	def __init__(self, model, optimizer, train_loader, valid_loader, margin, lambda_, max_gnorm, patience, lr_factor, label_smoothing, verbose=-1, cp_name=None, save_cp=False, checkpoint_path=None, checkpoint_epoch=None, pretrain=False, swap=False, cuda=True, logger=None):
		if checkpoint_path is None:
			# Save to current directory
			self.checkpoint_path = os.getcwd()
		else:
			self.checkpoint_path = checkpoint_path
			if not os.path.isdir(self.checkpoint_path):
				os.mkdir(self.checkpoint_path)

		self.save_epoch_fmt = os.path.join(self.checkpoint_path, cp_name) if cp_name else os.path.join(self.checkpoint_path, 'checkpoint_{}ep.pt')
		self.cuda_mode = cuda
		self.pretrain = pretrain
		self.model = model
		self.optimizer = optimizer
		self.patience = patience
		self.max_gnorm = max_gnorm
		self.lr_factor = lr_factor
		self.lambda_ = lambda_
		self.swap = swap
		self.margin = margin
		self.train_loader = train_loader
		self.valid_loader = valid_loader
		self.total_iters = 0
		self.cur_epoch = 0
		self.harvester = HardestNegativeTripletSelector(margin=0.1, cpu=not self.cuda_mode)
		self.harvester_val = AllTripletSelector()
		self.verbose = verbose
		self.save_cp = save_cp
		self.device = next(self.model.parameters()).device
		self.history = {'train_loss': [], 'train_loss_batch': [], 'ce_loss': [], 'ce_loss_batch': [], 'triplet_loss': [], 'triplet_loss_batch': []}
		self.disc_label_smoothing = label_smoothing*0.5
		self.base_lr = self.optimizer.param_groups[0]['lr']
		self.logger = logger

		if label_smoothing>0.0:
			self.ce_criterion = LabelSmoothingLoss(label_smoothing, lbl_set_size=1000)
		else:
			self.ce_criterion = torch.nn.CrossEntropyLoss()

		if self.valid_loader is not None:
			self.history['cos_eer'] = []
			self.history['acc_1'] = []
			self.history['acc_5'] = []

		if checkpoint_epoch is not None:
			self.load_checkpoint(self.save_epoch_fmt.format(checkpoint_epoch))

	def train(self, n_epochs=1, save_every=1):

		while (self.cur_epoch < n_epochs):

			np.random.seed()
			if isinstance(self.train_loader.dataset, Loader):
				self.train_loader.dataset.update_lists()

			adjust_learning_rate(self.optimizer, self.cur_epoch, self.base_lr, self.patience, self.lr_factor)

			if self.verbose>1:
				print(' ')
				print('Epoch {}/{}'.format(self.cur_epoch+1, n_epochs))
				train_iter = tqdm(enumerate(self.train_loader))
			else:
				train_iter = enumerate(self.train_loader)

			if self.pretrain:

				ce_epoch=0.0
				for t, batch in train_iter:
					ce = self.pretrain_step(batch)
					self.history['train_loss_batch'].append(ce)
					ce_epoch+=ce
					if self.logger:
						self.logger.add_scalar('Train/Cross entropy', ce, self.total_iters)
						self.logger.add_scalar('Info/LR', self.optimizer.param_groups[0]['lr'], self.total_iters)
					self.total_iters += 1

				self.history['train_loss'].append(ce_epoch/(t+1))

				if self.verbose>1:
					print('Train loss: {:0.4f}'.format(self.history['train_loss'][-1]))

			else:

				train_loss_epoch=0.0
				ce_loss_epoch=0.0
				triplet_loss_epoch=0.0
				for t, batch in train_iter:
					train_loss, ce_loss, triplet_loss = self.train_step(batch)
					self.history['train_loss_batch'].append(train_loss)
					self.history['ce_loss_batch'].append(ce_loss)
					self.history['triplet_loss_batch'].append(triplet_loss)
					train_loss_epoch+=train_loss
					ce_loss_epoch+=ce_loss
					triplet_loss_epoch+=triplet_loss
					if self.logger:
						self.logger.add_scalar('Train/Total train Loss', train_loss, self.total_iters)
						self.logger.add_scalar('Train/Triplet Loss', triplet_loss, self.total_iters)
						self.logger.add_scalar('Train/Cross entropy', ce_loss, self.total_iters)
						self.logger.add_scalar('Info/LR', self.optimizer.param_groups[0]['lr'], self.total_iters)
					self.total_iters += 1

				self.history['train_loss'].append(train_loss_epoch/(t+1))
				self.history['ce_loss'].append(ce_loss_epoch/(t+1))
				self.history['triplet_loss'].append(triplet_loss_epoch/(t+1))

				if self.verbose>1:
					print(' ')
					print('Total train loss: {:0.4f}'.format(self.history['train_loss'][-1]))
					print('CE loss: {:0.4f}'.format(self.history['ce_loss'][-1]))
					print('Triplet loss: {:0.4f}'.format(self.history['triplet_loss'][-1]))
					print(' ')

			if self.valid_loader is not None:

				tot_correct_1, tot_correct_5, tot_ = 0, 0, 0
				cos_scores, labels = None, None

				for t, batch in enumerate(self.valid_loader):
					correct_1, correct_5, total, cos_scores_batch, labels_batch = self.valid(batch)

					if cos_scores is None:
						# First batch: initialize the accumulators
						cos_scores, labels = cos_scores_batch, labels_batch
					else:
						cos_scores = np.concatenate([cos_scores, cos_scores_batch], 0)
						labels = np.concatenate([labels, labels_batch], 0)

					tot_correct_1 += correct_1
					tot_correct_5 += correct_5
					tot_ += total

				self.history['cos_eer'].append(compute_eer(labels, cos_scores))
				self.history['acc_1'].append(float(tot_correct_1)/tot_)
				self.history['acc_5'].append(float(tot_correct_5)/tot_)
				if self.logger:
					self.logger.add_scalar('Valid/Cosine EER', self.history['cos_eer'][-1], self.total_iters-1)
					self.logger.add_scalar('Valid/Best Cosine EER', np.min(self.history['cos_eer']), self.total_iters-1)
					self.logger.add_scalar('Valid/ACC-1', self.history['acc_1'][-1], self.total_iters-1)
					self.logger.add_scalar('Valid/Best ACC-1', np.max(self.history['acc_1']), self.total_iters-1)
					self.logger.add_scalar('Valid/ACC-5', self.history['acc_5'][-1], self.total_iters-1)
					self.logger.add_scalar('Valid/Best ACC-5', np.max(self.history['acc_5']), self.total_iters-1)
					self.logger.add_pr_curve('Cosine ROC', labels=labels, predictions=cos_scores, global_step=self.total_iters-1)
					self.logger.add_histogram('Valid/COS_Scores', values=cos_scores, global_step=self.total_iters-1)
					self.logger.add_histogram('Valid/Labels', values=labels, global_step=self.total_iters-1)

				if self.verbose>1:
					print(' ')
					print('Current cos EER, best cos EER, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['cos_eer'][-1], np.min(self.history['cos_eer']), 1+np.argmin(self.history['cos_eer'])))
					print('Current Top 1 Acc, best Top 1 Acc, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['acc_1'][-1], np.max(self.history['acc_1']), 1+np.argmax(self.history['acc_1'])))
					print('Current Top 5 Acc, best Top 5 Acc, and epoch: {:0.4f}, {:0.4f}, {}'.format(self.history['acc_5'][-1], np.max(self.history['acc_5']), 1+np.argmax(self.history['acc_5'])))

			if self.verbose>1:
				print('Current LR: {}'.format(self.optimizer.param_groups[0]['lr']))

			self.cur_epoch += 1

			if self.valid_loader is not None and self.save_cp and (self.cur_epoch % save_every == 0 or self.history['cos_eer'][-1] < np.min([np.inf]+self.history['cos_eer'][:-1])):
				self.checkpointing()
			elif self.save_cp and self.cur_epoch % save_every == 0:
				self.checkpointing()

		if self.verbose>1:
			print('Training done!')

		if self.valid_loader is not None:
			if self.verbose>1:
				print('Best cos eer and corresponding epoch: {:0.4f}, {}'.format(np.min(self.history['cos_eer']), 1+np.argmin(self.history['cos_eer'])))

			return [np.min(self.history['cos_eer'])]
		else:
			return [np.min(self.history['train_loss'])]

	def train_step(self, batch):

		self.model.train()
		self.optimizer.zero_grad()

		if isinstance(self.train_loader.dataset, Loader):
			x_1, x_2, x_3, x_4, x_5, y = batch
			x = torch.cat([x_1, x_2, x_3, x_4, x_5], dim=0)
			y = torch.cat(5*[y], dim=0).squeeze().contiguous()
		else:
			x, y = batch

		x = x.to(self.device, non_blocking=True)
		y = y.to(self.device, non_blocking=True)

		embeddings, out = self.model.forward(x)
		embeddings_norm = F.normalize(embeddings, p=2, dim=1)

		ce_loss = self.ce_criterion(self.model.out_proj(out, y), y)

		# Get all triplets now for bin classifier
		triplets_idx, entropy_indices = self.harvester.get_triplets(embeddings_norm.detach(), y)
		triplets_idx = triplets_idx.to(self.device, non_blocking=True)

		emb_a = torch.index_select(embeddings_norm, 0, triplets_idx[:, 0])
		emb_p = torch.index_select(embeddings_norm, 0, triplets_idx[:, 1])
		emb_n = torch.index_select(embeddings_norm, 0, triplets_idx[:, 2])

		loss_metric = self.triplet_loss(emb_a, emb_p, emb_n)

		loss = ce_loss + loss_metric

		entropy_regularizer = torch.nn.functional.pairwise_distance(embeddings_norm, embeddings_norm[entropy_indices,:]).mean()
		loss -= entropy_regularizer*self.lambda_

		loss.backward()
		grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_gnorm)
		self.optimizer.step()

		if self.logger:
			self.logger.add_scalar('Info/Grad_norm', grad_norm, self.total_iters)

		return loss.item(), ce_loss.item(), loss_metric.item()


	def pretrain_step(self, batch):

		self.model.train()
		self.optimizer.zero_grad()

		x, y = batch

		x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True).squeeze()

		embeddings, out = self.model.forward(x)

		loss = F.cross_entropy(self.model.out_proj(out, y), y)

		loss.backward()
		# Clip gradients before the update so the clipping actually affects the step
		grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_gnorm)
		self.optimizer.step()

		if self.logger:
			self.logger.add_scalar('Info/Grad_norm', grad_norm, self.total_iters)

		return loss.item()


	def valid(self, batch):

		self.model.eval()

		with torch.no_grad():

			if isinstance(self.valid_loader.dataset, Loader):
				x_1, x_2, x_3, x_4, x_5, y = batch
				x = torch.cat([x_1, x_2, x_3, x_4, x_5], dim=0)
				y = torch.cat(5*[y], dim=0).squeeze().contiguous()
			else:
				x, y = batch

			x = x.to(self.device, non_blocking=True)
			y = y.to(self.device, non_blocking=True)

			embeddings, out = self.model.forward(x)
			embeddings_norm = F.normalize(embeddings, p=2, dim=1)

			out = self.model.out_proj(out, y)

			pred = F.softmax(out, dim=1)
			(correct_1, correct_5) = correct_topk(pred, y, (1,5))

			# Get all triplets now for bin classifier
			triplets_idx = self.harvester_val.get_triplets(embeddings.detach(), y)
			triplets_idx = triplets_idx.to(self.device, non_blocking=True)

			emb_a = torch.index_select(embeddings_norm, 0, triplets_idx[:, 0])
			emb_p = torch.index_select(embeddings_norm, 0, triplets_idx[:, 1])
			emb_n = torch.index_select(embeddings_norm, 0, triplets_idx[:, 2])

			cos_scores_p = torch.nn.functional.cosine_similarity(emb_a, emb_p)
			cos_scores_n = torch.nn.functional.cosine_similarity(emb_a, emb_n)

		return correct_1, correct_5, x.size(0), np.concatenate([cos_scores_p.detach().cpu().numpy(), cos_scores_n.detach().cpu().numpy()], 0), np.concatenate([np.ones(cos_scores_p.size(0)), np.zeros(cos_scores_p.size(0))], 0)

	def triplet_loss(self, emba, embp, embn, reduce_=True):

		loss_ = torch.nn.TripletMarginLoss(margin=self.margin, p=2.0, eps=1e-06, swap=self.swap, reduction='mean' if reduce_ else 'none')(emba, embp, embn)

		return loss_

	def checkpointing(self):

		# Checkpointing
		if self.verbose>1:
			print('Checkpointing...')
		ckpt = {'model_state': self.model.state_dict(),
		'optimizer_state': self.optimizer.state_dict(),
		'history': self.history,
		'total_iters': self.total_iters,
		'cur_epoch': self.cur_epoch}
		# str.format is a no-op when save_epoch_fmt has no '{}' placeholder (fixed cp_name)
		torch.save(ckpt, self.save_epoch_fmt.format(self.cur_epoch))

	def load_checkpoint(self, ckpt):

		if os.path.isfile(ckpt):

			ckpt = torch.load(ckpt, map_location = lambda storage, loc: storage)
			# Load model state
			self.model.load_state_dict(ckpt['model_state'])
			# Load optimizer state
			self.optimizer.load_state_dict(ckpt['optimizer_state'])
			# Load history
			self.history = ckpt['history']
			self.total_iters = ckpt['total_iters']
			self.cur_epoch = ckpt['cur_epoch']
			if self.cuda_mode:
				self.model = self.model.cuda(self.device)

		else:
			print('No checkpoint found at: {}'.format(ckpt))

	def print_grad_norms(self):
		norm = 0.0
		for params in list(self.model.parameters()):
			if params.grad is not None:
				norm += params.grad.norm(2).item()
		print('Sum of grads norms: {}'.format(norm))
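The TrainLoop above also calls two helpers that are not defined in this example: adjust_learning_rate (in train()) and correct_topk (in valid()). The sketches below are hypothetical reconstructions based only on how those helpers are called; the project's actual implementations may differ.

import torch

def adjust_learning_rate(optimizer, epoch, base_lr, patience, lr_factor, min_lr=1e-8):
	# Assumed step decay: shrink the base LR by lr_factor every `patience` epochs.
	lr = max(base_lr * (lr_factor ** (epoch // patience)), min_lr)
	for param_group in optimizer.param_groups:
		param_group['lr'] = lr

def correct_topk(pred, target, topk=(1,)):
	# Count of samples whose true class appears among the top-k predicted classes.
	maxk = max(topk)
	_, idx = pred.topk(maxk, dim=1)
	hits = idx.eq(target.view(-1, 1).expand_as(idx))
	return tuple(hits[:, :k].any(dim=1).float().sum().item() for k in topk)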