def train(self, train_schedule, initial_epoch=0):
    for sch in train_schedule:
        type = sch.get("type", None) or self.__init_type_by_loss__(sch["loss"])
        print(">>>> Train %s..." % type)
        self.basic_model.trainable = True
        self.__init_optimizer__(sch.get("optimizer", None))
        self.__init_dataset__(type)
        self.__init_model__(type)

        if sch.get("centerloss", False):
            print(">>>> Train centerloss...")
            if type == self.triplet:
                print(">>>> Center loss combined with triplet, skip")
                continue
            center_loss = sch["loss"]
            if center_loss.__class__.__name__ != losses.CenterLoss.__name__:
                feature_dim = self.basic_model.output_shape[-1]
                initial_file = self.basic_model.name + "_centers.npy"
                logits_loss = sch["loss"]
                center_loss = losses.CenterLoss(
                    self.classes,
                    feature_dim=feature_dim,
                    factor=1.0,
                    initial_file=initial_file,
                    logits_loss=logits_loss,
                )
                sch["loss"] = center_loss
            self.model = keras.models.Model(
                self.model.inputs[0],
                keras.layers.concatenate([self.basic_model.outputs[0], self.model.outputs[-1]]),
            )
        else:
            center_loss = None
        self.__init_metrics_callbacks__(type, center_loss, sch.get("bottleneckOnly", False))

        if sch.get("bottleneckOnly", False):
            print(">>>> Train bottleneckOnly...")
            self.basic_model.trainable = False
            self.__basic_train__(sch["loss"], sch["epoch"], initial_epoch=0)
            self.basic_model.trainable = True
        else:
            self.__basic_train__(sch["loss"], initial_epoch + sch["epoch"], initial_epoch=initial_epoch)
            initial_epoch += sch["epoch"]

        print(">>>> Train %s DONE!!! epochs = %s, model.stop_training = %s"
              % (type, self.model.history.epoch, self.model.stop_training))
        print(">>>> My history:")
        self.my_hist.print_hist()
        if self.model.stop_training:
            print(">>>> But it's an early stop, break...")
            break
        print()
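# Usage sketch for the schedule-driven train() above. Hedged: `tt` is assumed
# to be a trainer instance exposing this method, and the loss choices below are
# purely illustrative -- any Keras loss (or one from the local `losses` module)
# is scheduled the same way.
def example_train_schedule(tt):
    sch = [
        # "type" is omitted, so it is inferred from the loss via __init_type_by_loss__.
        {"loss": keras.losses.CategoricalCrossentropy(from_logits=True), "optimizer": "adam", "epoch": 5},
        # "centerloss": True wraps the same logits loss in losses.CenterLoss and
        # rebuilds the model to emit embeddings alongside logits.
        {"loss": keras.losses.CategoricalCrossentropy(from_logits=True), "centerloss": True, "epoch": 10},
    ]
    # initial_epoch accumulates across entries, so LR schedulers and history
    # callbacks see one continuous epoch counter.
    tt.train(sch, initial_epoch=0)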
import numpy as np
import torch

# `network`, `losses` and `visualizer` are project-local modules.


def generate_embeddings(train_dataset, val_dataset, device, embed_type,
                        n_epochs=10, batch_size=32, save_path=None):
    train_dl = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    val_dl = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=batch_size)
    gnet = network.GenreNet(embed_type).to(device)

    if embed_type not in ['softmax', 'center-softmax', 'triplet', 'sphere', 'cos']:
        raise ValueError('Invalid embedding type!')

    # The backbone (gnet.net) gets a much smaller learning rate than the
    # freshly initialized classifier and embedding head.
    if embed_type == 'softmax':
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam([
            {'params': gnet.net.parameters(), 'lr': 1e-6},
            {'params': gnet.classifier.parameters(), 'lr': 1e-3},
            {'params': gnet.embedding_layer.parameters(), 'lr': 1e-3},
        ])
    elif embed_type == 'center-softmax':
        criterion = torch.nn.CrossEntropyLoss()
        center_loss = losses.CenterLoss(num_classes=10, feat_dim=32, use_gpu=True)
        optimizer = torch.optim.Adam([
            {'params': gnet.net.parameters(), 'lr': 1e-5},
            {'params': gnet.classifier.parameters(), 'lr': 1e-3},
            {'params': gnet.embedding_layer.parameters(), 'lr': 1e-3},
            # The class centers are trainable parameters of the center loss.
            {'params': center_loss.parameters(), 'lr': 1e-3},
        ])
    elif embed_type == 'cos':
        criterion = losses.CosLoss()
        optimizer = torch.optim.Adam([
            {'params': gnet.net.parameters(), 'lr': 1e-6},
            {'params': gnet.classifier.parameters(), 'lr': 1e-3},
            {'params': gnet.embedding_layer.parameters(), 'lr': 1e-3},
        ])
    # NOTE: 'triplet' and 'sphere' pass the validity check above but have no
    # criterion/optimizer configured here, so they would fail with a NameError.

    train_tracker = []
    val_tracker = []
    for epoch in range(n_epochs):
        # Training pass.
        gnet.train()
        epoch_losses = []
        total_samples = 1e-5  # avoids division by zero on an empty loader
        correct_samples = 0
        for batch in train_dl:
            optimizer.zero_grad()
            X, y = batch[0].to(device), batch[1].to(device).long()
            embeddings, predictions = gnet(X)
            if embed_type == 'softmax':
                _, y_pred = torch.max(predictions, 1)
                total_samples += y.size(0)
                correct_samples += (y_pred == y).sum().item()
                loss = criterion(predictions, y)
            elif embed_type in ('sphere', 'cos'):
                # These criteria return (loss, batch accuracy).
                loss, acc_batch = criterion(predictions, y)
                correct_samples += y.size(0) * acc_batch
                total_samples += y.size(0)
            elif embed_type == 'center-softmax':
                _, y_pred = torch.max(predictions, 1)
                total_samples += y.size(0)
                correct_samples += (y_pred == y).sum().item()
                closs = center_loss(embeddings, y)
                loss = criterion(predictions, y) + closs
            epoch_losses.append(loss.item())
            loss.backward()
            optimizer.step()
        epoch_loss = np.mean(epoch_losses)
        epoch_acc = correct_samples / total_samples
        train_tracker.append((epoch_loss, epoch_acc))
        if (epoch + 1) % 5 == 0:
            print("Train Loss after epoch {} = {}".format(epoch, epoch_loss))
            print("Train Accuracy after epoch {} = {}".format(epoch, epoch_acc))
            # torch.save(gnet, './checkpoints/gnet_model_{}_epoch_{}.pth'.format(embed_type, epoch))

        # Validation pass: no gradients needed.
        epoch_losses = []
        total_samples = 1e-5
        correct_samples = 0
        gnet.eval()
        with torch.no_grad():
            for batch in val_dl:
                X, y = batch[0].to(device), batch[1].to(device).long()
                embeddings, predictions = gnet(X)
                if embed_type == 'softmax':
                    _, y_pred = torch.max(predictions, 1)
                    total_samples += y.size(0)
                    correct_samples += (y_pred == y).sum().item()
                    loss = criterion(predictions, y)
                elif embed_type in ('sphere', 'cos'):
                    loss, acc_batch = criterion(predictions, y)
                    correct_samples += y.size(0) * acc_batch
                    total_samples += y.size(0)
                elif embed_type == 'center-softmax':
                    _, y_pred = torch.max(predictions, 1)
                    total_samples += y.size(0)
                    correct_samples += (y_pred == y).sum().item()
                    closs = center_loss(embeddings, y)
                    loss = criterion(predictions, y) + closs
                epoch_losses.append(loss.item())
        epoch_loss = np.mean(epoch_losses)
        epoch_acc = correct_samples / total_samples
        val_tracker.append((epoch_loss, epoch_acc))
        if (epoch + 1) % 5 == 0:
            print("Val Loss after epoch {} = {}".format(epoch, epoch_loss))
            print("Val Accuracy after epoch {} = {}".format(epoch, epoch_acc))
            print('\n')

    visualizer.visualize_embeddings(val_dl, gnet, device)

    if save_path is not None:
        # Dump final embeddings for both splits as .npy files.
        gnet.eval()
        with torch.no_grad():
            train_embeddings = None
            for batch in train_dl:
                X = batch[0].to(device)
                embeddings, predictions = gnet(X)
                embeddings = embeddings.cpu().numpy()
                train_embeddings = embeddings if train_embeddings is None else np.concatenate([train_embeddings, embeddings])
            val_embeddings = None
            for batch in val_dl:
                X = batch[0].to(device)
                embeddings, predictions = gnet(X)
                embeddings = embeddings.cpu().numpy()
                val_embeddings = embeddings if val_embeddings is None else np.concatenate([val_embeddings, embeddings])
        np.save('{}/{}_{}.npy'.format(save_path, embed_type, 'train'), train_embeddings)
        np.save('{}/{}_{}.npy'.format(save_path, embed_type, 'val'), val_embeddings)

    # (loss, accuracy) per epoch for each split.
    return train_tracker, val_tracker
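# Minimal driver for generate_embeddings() above -- a sketch assuming datasets
# that yield (input, integer-label) pairs for the 10 classes GenreNet expects;
# the './embeddings' directory is illustrative and must already exist.
def example_generate_embeddings(train_dataset, val_dataset):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # 'center-softmax' optimizes CrossEntropyLoss plus CenterLoss on the 32-dim
    # embeddings; final embeddings for both splits are saved as .npy files.
    generate_embeddings(train_dataset, val_dataset, device,
                        embed_type='center-softmax',
                        n_epochs=20, batch_size=32,
                        save_path='./embeddings')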
def train(self, train_schedule, initial_epoch=0):
    # A single schedule dict is accepted and wrapped into a one-element list.
    train_schedule = [train_schedule] if isinstance(train_schedule, dict) else train_schedule
    for sch in train_schedule:
        if sch.get("loss", None) is None:
            continue
        cur_loss = sch["loss"]
        type = sch.get("type", None) or self.__init_type_by_loss__(cur_loss)
        print(">>>> Train %s..." % type)

        if sch.get("triplet", False) or sch.get("tripletAll", False) or type == self.triplet:
            self.__init_dataset_triplet__()
        else:
            self.__init_dataset_softmax__()

        self.basic_model.trainable = True
        self.__init_optimizer__(sch.get("optimizer", None))
        self.__init_model__(type, sch.get("lossTopK", 1))

        # loss_weights
        cur_loss = [cur_loss]
        self.callbacks = self.my_evals + self.custom_callbacks + self.basic_callbacks
        loss_weights = None
        if sch.get("centerloss", False) and type != self.center:
            print(">>>> Attach centerloss...")
            emb_shape = self.basic_model.output_shape[-1]
            initial_file = os.path.splitext(self.save_path)[0] + "_centers.npy"
            center_loss = losses.CenterLoss(self.classes, emb_shape=emb_shape, initial_file=initial_file)
            cur_loss = [center_loss, *cur_loss]
            loss_weights = {ii: 1.0 for ii in self.model.output_names}
            nns = self.model.output_names
            # Prepend the embedding output so the center loss can see it.
            self.model = keras.models.Model(self.model.inputs[0], self.basic_model.outputs + self.model.outputs)
            self.model.output_names[0] = self.center + "_embedding"
            for id, nn in enumerate(nns):
                self.model.output_names[id + 1] = nn
            self.callbacks = self.my_evals + self.custom_callbacks + [center_loss.save_centers_callback] + self.basic_callbacks
            loss_weights.update({self.model.output_names[0]: float(sch["centerloss"])})
        if (sch.get("triplet", False) or sch.get("tripletAll", False)) and type != self.triplet:
            alpha = sch.get("alpha", 0.35)
            triplet_loss = (losses.BatchHardTripletLoss(alpha=alpha)
                            if sch.get("triplet", False)
                            else losses.BatchAllTripletLoss(alpha=alpha))
            print(">>>> Attach tripletloss: %s, alpha = %f..." % (triplet_loss.__class__.__name__, alpha))
            cur_loss = [triplet_loss, *cur_loss]
            loss_weights = loss_weights if loss_weights is not None else {ii: 1.0 for ii in self.model.output_names}
            nns = self.model.output_names
            self.model = keras.models.Model(self.model.inputs[0], self.basic_model.outputs + self.model.outputs)
            self.model.output_names[0] = self.triplet + "_embedding"
            for id, nn in enumerate(nns):
                self.model.output_names[id + 1] = nn
            loss_weights.update({self.model.output_names[0]: float(sch.get("triplet", False) or sch.get("tripletAll", False))})
        if self.is_distiller:
            loss_weights = [1, sch.get("distill", 7)]
            print(">>>> Train distiller model...")
            self.model = keras.models.Model(self.model.inputs[0], [self.model.outputs[-1], self.basic_model.outputs[0]])
            cur_loss = [cur_loss[-1], losses.distiller_loss]
        print(">>>> loss_weights:", loss_weights)
        self.metrics = {ii: None if "embedding" in ii else "accuracy" for ii in self.model.output_names}

        try:
            import tensorflow_addons as tfa
        except ImportError:
            pass
        else:
            if isinstance(self.optimizer, tfa.optimizers.weight_decay_optimizers.DecoupledWeightDecayExtension):
                print(">>>> Insert weight decay callback...")
                lr_base, wd_base = self.optimizer.lr.numpy(), self.optimizer.weight_decay.numpy()
                wd_callback = myCallbacks.OptimizerWeightDecay(lr_base, wd_base)
                self.callbacks.insert(-2, wd_callback)  # should be after lr_scheduler

        if sch.get("bottleneckOnly", False):
            print(">>>> Train bottleneckOnly...")
            self.basic_model.trainable = False
            self.callbacks = self.callbacks[len(self.my_evals):]  # Exclude evaluation callbacks
            self.__basic_train__(cur_loss, sch["epoch"], initial_epoch=0, loss_weights=loss_weights)
            self.basic_model.trainable = True
        else:
            self.__basic_train__(cur_loss, initial_epoch + sch["epoch"], initial_epoch=initial_epoch, loss_weights=loss_weights)
            initial_epoch += sch["epoch"]

        print(">>>> Train %s DONE!!! epochs = %s, model.stop_training = %s"
              % (type, self.model.history.epoch, self.model.stop_training))
        print(">>>> My history:")
        self.my_hist.print_hist()
        if self.model.stop_training:
            print(">>>> But it's an early stop, break...")
            break
        print()
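# Usage sketch for the train() variant above, which attaches auxiliary losses
# through schedule flags rather than wrapper classes. losses.ArcfaceLoss is an
# assumption standing in for whatever logits loss the local `losses` module
# provides; the weights mirror how the flags are read above.
def example_combined_schedule(tt):
    sch = {
        "loss": losses.ArcfaceLoss(),  # assumed logits loss from the local module
        "epoch": 15,
        "centerloss": 0.01,  # truthy attaches CenterLoss; float(value) becomes its loss weight
        "triplet": 64,       # truthy attaches BatchHardTripletLoss; float(value) is its weight
        "alpha": 0.35,       # triplet margin, the default read by sch.get("alpha", 0.35)
    }
    # A bare dict is accepted: train() wraps it into a one-element schedule.
    tt.train(sch, initial_epoch=0)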
def train(self, train_schedule, initial_epoch=0):
    for sch in train_schedule:
        if sch.get("loss", None) is None:
            continue
        cur_loss = sch["loss"]
        self.basic_model.trainable = True
        self.__init_optimizer__(sch.get("optimizer", None))
        if isinstance(cur_loss, losses.TripletLossWapper) and cur_loss.logits_loss is not None:
            type = sch.get("type", None) or self.__init_type_by_loss__(cur_loss.logits_loss)
            cur_loss.feature_dim = self.basic_model.output_shape[-1]
            print(">>>> Train Triplet + %s, feature_dim = %d ..." % (type, cur_loss.feature_dim))
            self.__init_dataset__(self.triplet)
            self.__init_model__(type)
            self.model = keras.models.Model(
                self.model.inputs[0],
                keras.layers.concatenate([self.basic_model.outputs[0], self.model.outputs[-1]]),
            )
            type = self.triplet + " + " + type
        else:
            type = sch.get("type", None) or self.__init_type_by_loss__(cur_loss)
            print(">>>> Train %s..." % type)
            self.__init_dataset__(type)
            self.__init_model__(type)

        if sch.get("centerloss", False):
            print(">>>> Train centerloss...")
            center_loss = cur_loss
            if not isinstance(center_loss, losses.CenterLoss):
                feature_dim = self.basic_model.output_shape[-1]
                # initial_file = self.basic_model.name + "_centers.npy"
                initial_file = os.path.splitext(self.save_path)[0] + "_centers.npy"
                logits_loss = cur_loss
                center_loss = losses.CenterLoss(
                    self.classes,
                    feature_dim=feature_dim,
                    factor=1.0,
                    initial_file=initial_file,
                    logits_loss=logits_loss,
                )
                cur_loss = center_loss
                # self.my_hist.custom_obj["centerloss"] = lambda: cur_loss.centerloss
            self.model = keras.models.Model(
                self.model.inputs[0],
                keras.layers.concatenate([self.basic_model.outputs[0], self.model.outputs[-1]]),
            )
            self.callbacks = self.my_evals + [center_loss.save_centers_callback] + self.basic_callbacks
        else:
            self.callbacks = self.my_evals + self.basic_callbacks
        self.metrics = None if type == self.triplet else [self.logits_accuracy]

        if sch.get("bottleneckOnly", False):
            print(">>>> Train bottleneckOnly...")
            self.basic_model.trainable = False
            self.callbacks = self.callbacks[len(self.my_evals):]  # Exclude evaluation callbacks
            self.__basic_train__(cur_loss, sch["epoch"], initial_epoch=0)
            self.basic_model.trainable = True
        else:
            self.__basic_train__(cur_loss, initial_epoch + sch["epoch"], initial_epoch=initial_epoch)
            initial_epoch += sch["epoch"]

        print(">>>> Train %s DONE!!! epochs = %s, model.stop_training = %s"
              % (type, self.model.history.epoch, self.model.stop_training))
        print(">>>> My history:")
        self.my_hist.print_hist()
        if self.model.stop_training:
            print(">>>> But it's an early stop, break...")
            break
        print()
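# Usage sketch for the wrapper-based variant above: the triplet branch fires
# when the schedule loss is a losses.TripletLossWapper carrying a logits_loss.
# The wrapper's constructor signature is an assumption -- train() only relies
# on its logits_loss and feature_dim attributes.
def example_triplet_wrapper_schedule(tt, logits_loss):
    # logits_loss: any logits-level loss, e.g. a softmax cross-entropy.
    sch = [{"loss": losses.TripletLossWapper(logits_loss=logits_loss), "epoch": 10}]
    # Trains triplet loss and the wrapped logits loss against a concatenated
    # [embedding, logits] output.
    tt.train(sch, initial_epoch=0)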