Example #1
    def train(self, train_schedule, initial_epoch=0):
        for sch in train_schedule:
            type = sch.get("type", None) or self.__init_type_by_loss__(sch["loss"])
            print(">>>> Train %s..." % type)

            self.basic_model.trainable = True
            self.__init_optimizer__(sch.get("optimizer", None))
            self.__init_dataset__(type)
            self.__init_model__(type)
            if sch.get("centerloss", False):
                print(">>>> Train centerloss...")
                if type == self.triplet:
                    print(">>>> Center loss combined with triplet, skip")
                    continue
                center_loss = sch["loss"]
                if not isinstance(center_loss, losses.CenterLoss):
                    feature_dim = self.basic_model.output_shape[-1]
                    initial_file = self.basic_model.name + "_centers.npy"
                    logits_loss = sch["loss"]
                    center_loss = losses.CenterLoss(self.classes,
                                                    feature_dim=feature_dim,
                                                    factor=1.0,
                                                    initial_file=initial_file,
                                                    logits_loss=logits_loss)
                    sch["loss"] = center_loss
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    keras.layers.concatenate(
                        [self.basic_model.outputs[0], self.model.outputs[-1]]))
            else:
                center_loss = None
            self.__init_metrics_callbacks__(type, center_loss,
                                            sch.get("bottleneckOnly", False))

            if sch.get("bottleneckOnly", False):
                print(">>>> Train bottleneckOnly...")
                self.basic_model.trainable = False
                self.__basic_train__(sch["loss"],
                                     sch["epoch"],
                                     initial_epoch=0)
                self.basic_model.trainable = True
            else:
                self.__basic_train__(sch["loss"],
                                     initial_epoch + sch["epoch"],
                                     initial_epoch=initial_epoch)
                initial_epoch += sch["epoch"]

            print(
                ">>>> Train %s DONE!!! epochs = %s, model.stop_training = %s" %
                (type, self.model.history.epoch, self.model.stop_training))
            print(">>>> My history:")
            self.my_hist.print_hist()
            if self.model.stop_training:
                print(">>>> But it's an early stop, break...")
                break
            print()
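
The train method above is driven by a list of schedule dicts; each entry supplies the keys the method reads ("loss", "epoch", and optionally "type", "optimizer", "centerloss", "bottleneckOnly"). A minimal schedule sketch, assuming tensorflow.keras is imported as keras and a trainer instance tt exposing this method (both are assumptions, not from the source):

# Hypothetical schedule; the concrete loss and optimizer are assumptions.
train_schedule = [
    # Warm up only the new classification head first.
    {"loss": keras.losses.CategoricalCrossentropy(from_logits=True),
     "optimizer": keras.optimizers.Adam(learning_rate=1e-3),
     "epoch": 5, "bottleneckOnly": True},
    # Then train the full model with center loss attached.
    {"loss": keras.losses.CategoricalCrossentropy(from_logits=True),
     "epoch": 20, "centerloss": True},
]
tt.train(train_schedule, initial_epoch=0)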
Example #2
import numpy as np
import torch

# Project-local modules assumed to sit alongside this file:
# network (GenreNet), losses (CenterLoss, CosLoss, ...), visualizer.
import network
import losses
import visualizer


def generate_embeddings(train_dataset,
                        val_dataset,
                        device,
                        embed_type,
                        n_epochs=10,
                        batch_size=32,
                        save_path=None):

    train_dl = torch.utils.data.DataLoader(train_dataset,
                                           shuffle=True,
                                           batch_size=batch_size)
    val_dl = torch.utils.data.DataLoader(val_dataset,
                                         shuffle=True,
                                         batch_size=batch_size)

    # Validate the embedding type before building the network.
    if embed_type not in ('softmax', 'center-softmax', 'triplet', 'sphere', 'cos'):
        raise ValueError('Invalid embedding type!')

    gnet = network.GenreNet(embed_type).to(device)

    if embed_type == 'softmax':
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam([
            {'params': gnet.net.parameters(), 'lr': 1e-6},
            {'params': gnet.classifier.parameters(), 'lr': 1e-3},
            {'params': gnet.embedding_layer.parameters(), 'lr': 1e-3},
        ])

    elif embed_type == 'center-softmax':
        criterion = torch.nn.CrossEntropyLoss()
        center_loss = losses.CenterLoss(num_classes=10,
                                        feat_dim=32,
                                        use_gpu=True)

        optimizer = torch.optim.Adam([
            {'params': gnet.net.parameters(), 'lr': 1e-5},
            {'params': gnet.classifier.parameters(), 'lr': 1e-3},
            {'params': gnet.embedding_layer.parameters(), 'lr': 1e-3},
            {'params': center_loss.parameters(), 'lr': 1e-3},
        ])

    elif embed_type == 'cos':
        criterion = losses.CosLoss()
        optimizer = torch.optim.Adam([
            {'params': gnet.net.parameters(), 'lr': 1e-6},
            {'params': gnet.classifier.parameters(), 'lr': 1e-3},
            {'params': gnet.embedding_layer.parameters(), 'lr': 1e-3},
        ])

    else:
        # 'triplet' and 'sphere' pass the validation above but never get a
        # criterion/optimizer in this snippet; fail early rather than hit a
        # NameError inside the training loop.
        raise NotImplementedError(
            "No criterion/optimizer set up for '%s'" % embed_type)

    train_tracker = []
    val_tracker = []
    for epoch in range(n_epochs):
        gnet.train()
        epoch_losses = []
        total_samples = 1e-5  # tiny epsilon so the accuracy division never hits zero
        correct_samples = 0
        for i, batch in enumerate(train_dl):
            gnet.zero_grad()
            optimizer.zero_grad()

            X, y = batch[0].to(device), batch[1].to(device).long()
            embeddings, predictions = gnet(X)

            if embed_type == 'softmax':
                _, y_pred = torch.max(predictions, 1)
                total_samples += y.size(0)
                correct_samples += (y_pred == y).sum().item()

                loss = criterion(predictions, y.long())

            elif embed_type == 'sphere':
                loss, acc_batch = criterion(predictions, y.long())
                correct_samples += y.size(0) * acc_batch
                total_samples += y.size(0)

            elif embed_type == 'cos':
                loss, acc_batch = criterion(predictions, y.long())
                correct_samples += y.size(0) * acc_batch
                total_samples += y.size(0)

            elif embed_type == 'center-softmax':
                _, y_pred = torch.max(predictions, 1)
                total_samples += y.size(0)
                correct_samples += (y_pred == y).sum().item()

                closs = center_loss(embeddings, y)
                loss = criterion(predictions, y.long()) + closs

            epoch_losses.append(loss.item())
            loss.backward()
            optimizer.step()

        epoch_loss = np.mean(epoch_losses)
        epoch_acc = correct_samples / total_samples
        train_tracker.append((epoch_loss, epoch_acc))
        if (epoch + 1) % 5 == 0:
            print("Train Loss after epoch {} = {}".format(
                epoch, np.mean(epoch_losses)))
            print("Train Accuracy after epoch {} = {}".format(
                epoch, correct_samples / total_samples))
            #torch.save(gnet, './checkpoints/gnet_model_{}_epoch_{}.pth'.format(embed_type,epoch))

        epoch_losses = []
        total_samples = 1e-5
        correct_samples = 0
        gnet.eval()
        for i, batch in enumerate(val_dl):
            X, y = batch[0].to(device), batch[1].to(device).long()
            # No graph is needed for validation, so run the forward pass
            # under no_grad instead of zeroing gradients afterwards.
            with torch.no_grad():
                embeddings, predictions = gnet(X)

            if embed_type == 'softmax':
                _, y_pred = torch.max(predictions, 1)
                total_samples += y.size(0)
                correct_samples += (y_pred == y).sum().item()

                loss = criterion(predictions, y.long())

            elif embed_type == 'sphere':
                loss, acc_batch = criterion(predictions, y.long())
                correct_samples += y.size(0) * acc_batch
                total_samples += y.size(0)

            elif embed_type == 'cos':
                loss, acc_batch = criterion(predictions, y.long())
                correct_samples += y.size(0) * acc_batch
                total_samples += y.size(0)

            elif embed_type == 'center-softmax':
                _, y_pred = torch.max(predictions, 1)
                total_samples += y.size(0)
                correct_samples += (y_pred == y).sum().item()

                closs = center_loss(embeddings, y)
                loss = criterion(predictions, y.long()) + closs

            epoch_losses.append(loss.item())

        epoch_loss = np.mean(epoch_losses)
        epoch_acc = correct_samples / total_samples
        val_tracker.append((epoch_loss, epoch_acc))
        if (epoch + 1) % 5 == 0:
            print("Val Loss after epoch {} = {}".format(
                epoch, np.mean(epoch_losses)))
            print("Val Accuracy after epoch {} = {}".format(
                epoch, correct_samples / total_samples))
            print('\n')
            visualizer.visualize_embeddings(val_dl, gnet, device)

    if save_path is not None:
        # NOTE: both DataLoaders shuffle, so the saved embeddings will not
        # line up with the original dataset order.
        train_embeddings = None
        for i, batch in enumerate(train_dl):
            X = batch[0].to(device)
            with torch.no_grad():
                embeddings, _ = gnet(X)
            embeddings = embeddings.cpu().numpy()
            if train_embeddings is None:
                train_embeddings = embeddings
            else:
                train_embeddings = np.concatenate(
                    [train_embeddings, embeddings])

        val_embeddings = None
        for i, batch in enumerate(val_dl):
            X = batch[0].to(device)
            with torch.no_grad():
                embeddings, _ = gnet(X)
            embeddings = embeddings.cpu().numpy()
            if val_embeddings is None:
                val_embeddings = embeddings
            else:
                val_embeddings = np.concatenate([val_embeddings, embeddings])

        np.save('{}/{}_{}.npy'.format(save_path, embed_type, 'train'),
                train_embeddings)
        np.save('{}/{}_{}.npy'.format(save_path, embed_type, 'val'),
                val_embeddings)
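
A minimal call sketch for generate_embeddings, assuming train_dataset and val_dataset are torch.utils.data.Dataset objects yielding (features, label) pairs with labels in range(10), to match CenterLoss(num_classes=10) above; the device choice and save directory are likewise assumptions:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hypothetical datasets; any Dataset yielding (features, label) pairs works.
generate_embeddings(train_dataset,
                    val_dataset,
                    device,
                    embed_type='center-softmax',
                    n_epochs=10,
                    batch_size=32,
                    save_path='./embeddings')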
Example #3
    def train(self, train_schedule, initial_epoch=0):
        if isinstance(train_schedule, dict):
            train_schedule = [train_schedule]
        for sch in train_schedule:
            if sch.get("loss", None) is None:
                continue
            cur_loss = sch["loss"]
            type = sch.get("type", None) or self.__init_type_by_loss__(cur_loss)
            print(">>>> Train %s..." % type)

            if sch.get("triplet", False) or sch.get(
                    "tripletAll", False) or type == self.triplet:
                self.__init_dataset_triplet__()
            else:
                self.__init_dataset_softmax__()

            self.basic_model.trainable = True
            self.__init_optimizer__(sch.get("optimizer", None))
            self.__init_model__(type, sch.get("lossTopK", 1))

            # loss_weights
            cur_loss = [cur_loss]
            self.callbacks = self.my_evals + self.custom_callbacks + self.basic_callbacks
            loss_weights = None
            if sch.get("centerloss", False) and type != self.center:
                print(">>>> Attach centerloss...")
                emb_shape = self.basic_model.output_shape[-1]
                initial_file = os.path.splitext(self.save_path)[0] + "_centers.npy"
                center_loss = losses.CenterLoss(self.classes,
                                                emb_shape=emb_shape,
                                                initial_file=initial_file)
                cur_loss = [center_loss, *cur_loss]
                loss_weights = {ii: 1.0 for ii in self.model.output_names}
                nns = self.model.output_names
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    self.basic_model.outputs + self.model.outputs)
                self.model.output_names[0] = self.center + "_embedding"
                for id, nn in enumerate(nns):
                    self.model.output_names[id + 1] = nn
                self.callbacks = self.my_evals + self.custom_callbacks + [
                    center_loss.save_centers_callback
                ] + self.basic_callbacks
                loss_weights.update(
                    {self.model.output_names[0]: float(sch["centerloss"])})

            if (sch.get("triplet", False)
                    or sch.get("tripletAll", False)) and type != self.triplet:
                alpha = sch.get("alpha", 0.35)
                if sch.get("triplet", False):
                    triplet_loss = losses.BatchHardTripletLoss(alpha=alpha)
                else:
                    triplet_loss = losses.BatchAllTripletLoss(alpha=alpha)
                print(">>>> Attach tripletloss: %s, alpha = %f..." %
                      (triplet_loss.__class__.__name__, alpha))

                cur_loss = [triplet_loss, *cur_loss]
                if loss_weights is None:
                    loss_weights = {ii: 1.0 for ii in self.model.output_names}
                nns = self.model.output_names
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    self.basic_model.outputs + self.model.outputs)
                self.model.output_names[0] = self.triplet + "_embedding"
                for id, nn in enumerate(nns):
                    self.model.output_names[id + 1] = nn
                triplet_weight = float(sch.get("triplet", False) or sch.get("tripletAll", False))
                loss_weights.update({self.model.output_names[0]: triplet_weight})

            if self.is_distiller:
                loss_weights = [1, sch.get("distill", 7)]
                print(">>>> Train distiller model...")
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    [self.model.outputs[-1], self.basic_model.outputs[0]])
                cur_loss = [cur_loss[-1], losses.distiller_loss]

            print(">>>> loss_weights:", loss_weights)
            self.metrics = {
                ii: None if "embedding" in ii else "accuracy"
                for ii in self.model.output_names
            }

            try:
                import tensorflow_addons as tfa
            except ImportError:
                pass
            else:
                if isinstance(self.optimizer,
                              tfa.optimizers.weight_decay_optimizers.DecoupledWeightDecayExtension):
                    print(">>>> Insert weight decay callback...")
                    lr_base = self.optimizer.lr.numpy()
                    wd_base = self.optimizer.weight_decay.numpy()
                    wd_callback = myCallbacks.OptimizerWeightDecay(lr_base, wd_base)
                    self.callbacks.insert(-2, wd_callback)  # should be after lr_scheduler

            if sch.get("bottleneckOnly", False):
                print(">>>> Train bottleneckOnly...")
                self.basic_model.trainable = False
                self.callbacks = self.callbacks[len(self.my_evals):]  # Exclude evaluation callbacks
                self.__basic_train__(cur_loss,
                                     sch["epoch"],
                                     initial_epoch=0,
                                     loss_weights=loss_weights)
                self.basic_model.trainable = True
            else:
                self.__basic_train__(cur_loss,
                                     initial_epoch + sch["epoch"],
                                     initial_epoch=initial_epoch,
                                     loss_weights=loss_weights)
                initial_epoch += sch["epoch"]

            print(
                ">>>> Train %s DONE!!! epochs = %s, model.stop_training = %s" %
                (type, self.model.history.epoch, self.model.stop_training))
            print(">>>> My history:")
            self.my_hist.print_hist()
            if self.model.stop_training:
                print(">>>> But it's an early stop, break...")
                break
            print()
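
In this variant the numeric values of "centerloss" and "triplet" / "tripletAll" double as loss weights: the method casts them with float() into loss_weights. A schedule sketch under that reading (the logits loss and the optimizer value are assumptions):

# Hypothetical schedule: numeric "centerloss" / "triplet" values become weights.
train_schedule = [
    {"loss": keras.losses.CategoricalCrossentropy(from_logits=True),
     "optimizer": "adam",
     "epoch": 10,
     "centerloss": 0.01,   # attach CenterLoss, weighted 0.01
     "triplet": 10,        # attach BatchHardTripletLoss, weighted 10.0
     "alpha": 0.35},       # triplet margin read by the method above
]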
Example #4
    def train(self, train_schedule, initial_epoch=0):
        for sch in train_schedule:
            if sch.get("loss", None) is None:
                continue
            cur_loss = sch["loss"]
            self.basic_model.trainable = True
            self.__init_optimizer__(sch.get("optimizer", None))

            if isinstance(cur_loss, losses.TripletLossWapper) and cur_loss.logits_loss is not None:
                type = sch.get("type", None) or self.__init_type_by_loss__(cur_loss.logits_loss)
                cur_loss.feature_dim = self.basic_model.output_shape[-1]
                print(">>>> Train Triplet + %s, feature_dim = %d ..." %
                      (type, cur_loss.feature_dim))
                self.__init_dataset__(self.triplet)
                self.__init_model__(type)
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    keras.layers.concatenate(
                        [self.basic_model.outputs[0], self.model.outputs[-1]]))
                type = self.triplet + " + " + type
            else:
                type = sch.get("type", None) or self.__init_type_by_loss__(cur_loss)
                print(">>>> Train %s..." % type)
                self.__init_dataset__(type)
                self.__init_model__(type)

            if sch.get("centerloss", False):
                print(">>>> Train centerloss...")
                center_loss = cur_loss
                if not isinstance(center_loss, losses.CenterLoss):
                    feature_dim = self.basic_model.output_shape[-1]
                    # initial_file = self.basic_model.name + "_centers.npy"
                    initial_file = os.path.splitext(self.save_path)[0] + "_centers.npy"
                    logits_loss = cur_loss
                    center_loss = losses.CenterLoss(self.classes,
                                                    feature_dim=feature_dim,
                                                    factor=1.0,
                                                    initial_file=initial_file,
                                                    logits_loss=logits_loss)
                    cur_loss = center_loss
                    # self.my_hist.custom_obj["centerloss"] = lambda : cur_loss.centerloss
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    keras.layers.concatenate(
                        [self.basic_model.outputs[0], self.model.outputs[-1]]))
                self.callbacks = self.my_evals + [
                    center_loss.save_centers_callback
                ] + self.basic_callbacks
            else:
                self.callbacks = self.my_evals + self.basic_callbacks
            self.metrics = None if type == self.triplet else [self.logits_accuracy]

            if sch.get("bottleneckOnly", False):
                print(">>>> Train bottleneckOnly...")
                self.basic_model.trainable = False
                self.callbacks = self.callbacks[len(self.my_evals):]  # Exclude evaluation callbacks
                self.__basic_train__(cur_loss, sch["epoch"], initial_epoch=0)
                self.basic_model.trainable = True
            else:
                self.__basic_train__(cur_loss,
                                     initial_epoch + sch["epoch"],
                                     initial_epoch=initial_epoch)
                initial_epoch += sch["epoch"]

            print(
                ">>>> Train %s DONE!!! epochs = %s, model.stop_training = %s" %
                (type, self.model.history.epoch, self.model.stop_training))
            print(">>>> My history:")
            self.my_hist.print_hist()
            if self.model.stop_training:
                print(">>>> But it's an early stop, break...")
                break
            print()
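
Example #4 takes its combined branch when the schedule's loss is a losses.TripletLossWapper whose logits_loss is set; the method only reads .logits_loss and writes .feature_dim on that object. A sketch of a schedule exercising that branch, with the wrapper's constructor arguments assumed (they are not shown in the source):

# Hypothetical: TripletLossWapper's constructor signature is an assumption; the
# method above only needs .logits_loss (not None) and a settable .feature_dim.
triplet_with_logits = losses.TripletLossWapper(
    alpha=0.35,
    logits_loss=keras.losses.CategoricalCrossentropy(from_logits=True))
train_schedule = [{"loss": triplet_with_logits, "epoch": 15}]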