def __init_dataset__(self, type, emb_loss_names):
    """Build or reuse ``self.train_ds`` for the requested loss configuration.

    Parameters
    ----------
    type : str
        Current loss type; compared against ``self.triplet`` / ``self.distill``.
        NOTE(review): name shadows the ``type`` builtin, kept for caller compatibility.
    emb_loss_names : list
        Names of embedding losses attached to training; determines whether a
        triplet dataset is needed and how distill labels are replicated.

    Side effects: sets ``self.train_ds``, ``self.steps_per_epoch``,
    ``self.is_triplet_dataset``, ``self.is_distill_ds``, ``self.classes`` and,
    for distill datasets, ``self.teacher_emb_size``.
    """
    init_as_triplet = self.triplet in emb_loss_names or type == self.triplet
    # Reuse the current dataset when it already matches the requested mode.
    # Distill datasets are always rebuilt because their label layout depends
    # on the current emb_loss_names.
    if self.train_ds is not None and init_as_triplet == self.is_triplet_dataset and not self.is_distill_ds:
        return

    dataset_params = {
        "data_path": self.data_path,
        "batch_size": self.batch_size,
        "random_status": self.random_status,
        "image_per_class": self.image_per_class,
        "teacher_model_interf": self.teacher_model_interf,
    }
    if init_as_triplet:
        print(">>>> Init triplet dataset...")
        if self.data_path.endswith(".tfrecord"):
            # tfrecord input cannot build true triplet batches; fall back to
            # the distill tfrecord pipeline with a warning.
            print(">>>> Combining tfrecord dataset with triplet is NOT recommended.")
            self.train_ds, self.steps_per_epoch = data.prepare_distill_dataset_tfrecord(**dataset_params)
        else:
            aa = data.Triplet_dataset(**dataset_params)
            self.train_ds, self.steps_per_epoch = aa.ds, aa.steps_per_epoch
        self.is_triplet_dataset = True
    else:
        print(">>>> Init softmax dataset...")
        if self.data_path.endswith(".tfrecord"):
            self.train_ds, self.steps_per_epoch = data.prepare_distill_dataset_tfrecord(**dataset_params)
        else:
            self.train_ds, self.steps_per_epoch = data.prepare_dataset(**dataset_params)
        self.is_triplet_dataset = False

    if tf.distribute.has_strategy():
        # Under a distribution strategy, apply the pre-configured dataset options.
        self.train_ds = self.train_ds.with_options(self.data_options)

    label_spec = self.train_ds.element_spec[-1]
    if isinstance(label_spec, tuple):
        # dataset with embedding values: label element is (teacher_embedding, label)
        self.is_distill_ds = True
        self.teacher_emb_size = label_spec[0].shape[-1]
        self.classes = label_spec[1].shape[-1]
        if type == self.distill:
            # Loss is distill type: [label * n, embedding]
            self.train_ds = self.train_ds.map(
                lambda xx, yy: (xx, yy[1:] * len(emb_loss_names) + yy[:1]))
        elif (self.distill in emb_loss_names and len(emb_loss_names) != 1) or (
                self.distill not in emb_loss_names and len(emb_loss_names) != 0):
            # Will attach distill loss as embedding loss, and there are other
            # embedding losses: [embedding, label * n]
            label_data_len = len(emb_loss_names) if self.distill in emb_loss_names else len(emb_loss_names) + 1
            self.train_ds = self.train_ds.map(
                lambda xx, yy: (xx, yy[:1] + yy[1:] * label_data_len))
    else:
        self.is_distill_ds = False
        self.classes = label_spec.shape[-1]
def __init_dataset_triplet__(self):
    """Ensure ``self.train_ds`` is a triplet dataset, building one if needed.

    No-op when a triplet dataset is already in place. Side effects: sets
    ``self.train_ds``, ``self.classes`` and ``self.is_triplet_dataset``.
    """
    if self.train_ds is None or not self.is_triplet_dataset:
        print(">>>> Init triplet dataset...")
        # batch_size = int(self.batch_size / 4 * 1.5)
        # Quarter the nominal batch size — presumably Triplet_dataset expands
        # each element into several images; TODO confirm the exact factor.
        batch_size = self.batch_size // 4
        tt = data.Triplet_dataset(
            self.data_path,
            batch_size=batch_size,
            random_status=self.random_status,
            random_crop=(100, 100, 3),
        )
        self.train_ds = tt.train_dataset
        # Number of classes is the width of the one-hot label element.
        self.classes = self.train_ds.element_spec[-1].shape[-1]
        self.is_triplet_dataset = True
def __init_dataset__(self, type):
    """Build or reuse ``self.train_ds`` according to the loss type.

    Parameters
    ----------
    type : str
        Current loss type; a triplet dataset is built when it equals
        ``self.triplet``, otherwise a softmax dataset.
        NOTE(review): name shadows the ``type`` builtin, kept for caller compatibility.

    Side effects: sets ``self.train_ds``, ``self.steps_per_epoch``,
    ``self.is_triplet_dataset`` and, for softmax, ``self.classes``.
    """
    if type == self.triplet:
        if self.train_ds is None or not self.is_triplet_dataset:
            print(">>>> Init triplet dataset...")
            # batch_size = int(self.batch_size / 4 * 1.5)
            # Quarter the nominal batch size — presumably Triplet_dataset
            # expands each element; TODO confirm the exact factor.
            batch_size = self.batch_size // 4
            tt = data.Triplet_dataset(
                self.data_path,
                batch_size=batch_size,
                random_status=self.random_status,
            )
            self.train_ds, self.steps_per_epoch = tt.train_dataset, tt.steps_per_epoch
            self.is_triplet_dataset = True
    else:
        if self.train_ds is None or self.is_triplet_dataset:
            print(">>>> Init softmax dataset...")
            self.train_ds, self.steps_per_epoch, self.classes = data.prepare_dataset(
                self.data_path,
                batch_size=self.batch_size,
                random_status=self.random_status,
            )
            self.is_triplet_dataset = False