Example #1
    def __init__(self, args, cfg, checkpoint_dir):
        self.batch_size = cfg.train.batch_size
        self.learning_rate = cfg.train.lr
        self.epochs = cfg.train.epochs
        self.start_epoch = 1
        self.lr_decay_epochs = cfg.train.lr_decay
        self.log_interval = cfg.train.log_inter
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_interval = cfg.train.ckpt_inter
        self.lambda_ = cfg.train.beta  # loss-weighting coefficient (named "beta" in the config)
        self.attr_dims = cfg.attr_dims
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.triplet_batch = 4

        self.fnet, self.optimizer, self.im_size = self.build_model(cfg)
        if os.path.exists(cfg.ckpt_name) and args.fine_tuning:
            # Resume from the latest checkpoint, e.g. "ckpt_epoch_12.pth".
            pth = glob(os.path.join(cfg.ckpt_name, "ckpt_epoch_*.pth"))
            pth = sorted(pth,
                         key=lambda p: int(os.path.basename(p).replace("ckpt_epoch_", "").replace(".pth", "")),
                         reverse=True)
            if pth:
                self.load(pth[0])
                # The epoch number is the digit run in the checkpoint file name.
                self.start_epoch = int(
                    ''.join(c for c in os.path.basename(pth[0]) if c.isdigit())) + 1
            
        self.attr_data, self.dataset_size, self.data_loader = self.prepare_dataloader(cfg)
        #self.attr_data = torch.from_numpy(self.attr_data).to(self.device)
        self.online_zsl_loss = losses.ZeroShotLearningLoss(self.attr_data)
        
        # Online triplet mining over P*K batches:
        # (batch_size // triplet_batch) identities with triplet_batch samples each.
        if cfg.train.triplet_mode == "batch_all":
            self.online_triplet_loss = losses.BatchAllTripletLoss(
                self.device, self.batch_size // self.triplet_batch, self.triplet_batch)
        else:
            self.online_triplet_loss = losses.BatchHardTripletLoss(
                self.device, self.batch_size // self.triplet_batch, self.triplet_batch)
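
For context, a companion training step for this class might look like the sketch below. The call signature loss(embeddings, labels) for both online losses and the weighting by self.lambda_ are assumptions for illustration, not the project's confirmed API; the example above only shows how the loss objects are constructed.

    def training_step(self, images, labels):
        """Hypothetical single optimization step (assumed loss signatures)."""
        images, labels = images.to(self.device), labels.to(self.device)
        embeddings = self.fnet(images)                               # model built in build_model()
        zsl_loss = self.online_zsl_loss(embeddings, labels)          # assumed signature
        triplet_loss = self.online_triplet_loss(embeddings, labels)  # assumed signature
        loss = zsl_loss + self.lambda_ * triplet_loss                # lambda_ comes from cfg.train.beta
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()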
Example #2
    def train(self, train_schedule, initial_epoch=0):
        train_schedule = [train_schedule] if isinstance(
            train_schedule, dict) else train_schedule
        for sch in train_schedule:
            if sch.get("loss", None) is None:
                continue
            cur_loss = sch["loss"]
            train_type = sch.get("type", None) or self.__init_type_by_loss__(cur_loss)
            print(">>>> Train %s..." % train_type)

            if sch.get("triplet", False) or sch.get(
                    "tripletAll", False) or train_type == self.triplet:
                self.__init_dataset_triplet__()
            else:
                self.__init_dataset_softmax__()

            self.basic_model.trainable = True
            self.__init_optimizer__(sch.get("optimizer", None))
            self.__init_model__(train_type, sch.get("lossTopK", 1))

            # loss_weights
            cur_loss = [cur_loss]
            self.callbacks = self.my_evals + self.custom_callbacks + self.basic_callbacks
            loss_weights = None
            if sch.get("centerloss", False) and type != self.center:
                print(">>>> Attach centerloss...")
                emb_shape = self.basic_model.output_shape[-1]
                initial_file = os.path.splitext(
                    self.save_path)[0] + "_centers.npy"
                center_loss = losses.CenterLoss(self.classes,
                                                emb_shape=emb_shape,
                                                initial_file=initial_file)
                cur_loss = [center_loss, *cur_loss]
                loss_weights = {ii: 1.0 for ii in self.model.output_names}
                nns = self.model.output_names
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    self.basic_model.outputs + self.model.outputs)
                self.model.output_names[0] = self.center + "_embedding"
                for idx, nn in enumerate(nns):
                    self.model.output_names[idx + 1] = nn
                self.callbacks = self.my_evals + self.custom_callbacks + [
                    center_loss.save_centers_callback
                ] + self.basic_callbacks
                loss_weights.update(
                    {self.model.output_names[0]: float(sch["centerloss"])})

            if (sch.get("triplet", False)
                    or sch.get("tripletAll", False)) and train_type != self.triplet:
                alpha = sch.get("alpha", 0.35)
                triplet_loss = losses.BatchHardTripletLoss(
                    alpha=alpha) if sch.get(
                        "triplet", False) else losses.BatchAllTripletLoss(
                            alpha=alpha)
                print(">>>> Attach tripletloss: %s, alpha = %f..." %
                      (triplet_loss.__class__.__name__, alpha))

                cur_loss = [triplet_loss, *cur_loss]
                if loss_weights is None:
                    loss_weights = {ii: 1.0 for ii in self.model.output_names}
                nns = self.model.output_names
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    self.basic_model.outputs + self.model.outputs)
                self.model.output_names[0] = self.triplet + "_embedding"
                for idx, nn in enumerate(nns):
                    self.model.output_names[idx + 1] = nn
                loss_weights.update({
                    self.model.output_names[0]:
                        float(sch.get("triplet", False) or sch.get("tripletAll", False))
                })

            if self.is_distiller:
                loss_weights = [1, sch.get("distill", 7)]
                print(">>>> Train distiller model...")
                self.model = keras.models.Model(
                    self.model.inputs[0],
                    [self.model.outputs[-1], self.basic_model.outputs[0]])
                cur_loss = [cur_loss[-1], losses.distiller_loss]

            print(">>>> loss_weights:", loss_weights)
            self.metrics = {
                ii: None if "embedding" in ii else "accuracy"
                for ii in self.model.output_names
            }

            try:
                import tensorflow_addons as tfa
            except ImportError:
                pass
            else:
                if isinstance(self.optimizer,
                              tfa.optimizers.weight_decay_optimizers.DecoupledWeightDecayExtension):
                    print(">>>> Insert weight decay callback...")
                    lr_base = self.optimizer.lr.numpy()
                    wd_base = self.optimizer.weight_decay.numpy()
                    wd_callback = myCallbacks.OptimizerWeightDecay(lr_base, wd_base)
                    self.callbacks.insert(-2, wd_callback)  # should run after the lr scheduler

            if sch.get("bottleneckOnly", False):
                print(">>>> Train bottleneckOnly...")
                self.basic_model.trainable = False
                self.callbacks = self.callbacks[len(self.my_evals):]  # Exclude evaluation callbacks
                self.__basic_train__(cur_loss,
                                     sch["epoch"],
                                     initial_epoch=0,
                                     loss_weights=loss_weights)
                self.basic_model.trainable = True
            else:
                self.__basic_train__(cur_loss,
                                     initial_epoch + sch["epoch"],
                                     initial_epoch=initial_epoch,
                                     loss_weights=loss_weights)
                initial_epoch += sch["epoch"]

            print(
                ">>>> Train %s DONE!!! epochs = %s, model.stop_training = %s" %
                (train_type, self.model.history.epoch, self.model.stop_training))
            print(">>>> My history:")
            self.my_hist.print_hist()
            if self.model.stop_training:
                print(">>>> But it's an early stop, break...")
                break
            print()
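
The keys read from each schedule entry above ("loss", "type", "optimizer", "epoch", "alpha", "centerloss", "triplet"/"tripletAll", "bottleneckOnly", "distill") suggest a usage like the sketch below. The trainer instance, the loss choice, and the two-stage layout are hypothetical, assuming a Keras-style setup as in the method itself.

from tensorflow import keras

# Hypothetical two-stage schedule: first train only the bottleneck head, then
# fine-tune everything with a batch-hard triplet loss (weight 1.0, taken from
# the "triplet" value) and a center loss weighted 0.1.
train_schedule = [
    {"loss": keras.losses.CategoricalCrossentropy(from_logits=True),
     "optimizer": "adam", "epoch": 5, "bottleneckOnly": True},
    {"loss": keras.losses.CategoricalCrossentropy(from_logits=True),
     "optimizer": "nadam", "epoch": 30,
     "triplet": 1.0, "alpha": 0.35, "centerloss": 0.1},
]
trainer.train(train_schedule, initial_epoch=0)  # `trainer` is a hypothetical instance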
Example #3

    n_classes_test = 8

    # data_x = partition['train'] + partition['validation'] + partition['test']
    # data_y = {**labels['train'], **labels['validation'], **labels['test']}
    data_x = partition['test']
    data_y = labels['test']
    test_dataset = PairLoader(data_x, data_y, data_source=data_type)
    test_batch_sampler = BalanceBatchSampler(dataset=test_dataset,
                                             n_classes=n_classes_test,
                                             n_samples=n_samples_test)
    test_loader = DataLoader(test_dataset,
                             batch_sampler=test_batch_sampler,
                             num_workers=num_workers)
    if triplet_method == "batch_hard":
        loss_fn = losses.BatchHardTripletLoss(margin=margin,
                                              squared=False,
                                              soft_margin=soft_margin)

    elif triplet_method == "batch_hardv2":
        loss_fn = losses.BatchHardTripletLoss_v2(margin=margin,
                                                 squared=False,
                                                 soft_margin=soft_margin)

    elif triplet_method == "batch_all":
        loss_fn = losses.BatchAllTripletLoss(margin=margin,
                                             squared=False,
                                             soft_margin=soft_margin)
    # rt = '/nfs/nas4/marzieh/marzieh/puf/ckpt/batch_hardv2/'
    # model_filename = rt + 'Run004,modelTriplet,Epoch_345,acc_0.999688.tar'
    # model_filename = Reporter(ckpt_root=os.path.join(ROOT_DIR, 'ckpt'),
    #                           exp=triplet_method, monitor='acc').select_best(run=run_name).selected_ckpt
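
A minimal evaluation sketch using the loader and loss configured above. The embedding model name, the (images, labels) batch layout, and the loss_fn(embeddings, labels) call signature are assumptions for illustration; batch-all variants in some implementations also return the fraction of active triplets.

import torch

model.eval()  # `model` is assumed to be the embedding network under test
total_loss, n_batches = 0.0, 0
with torch.no_grad():
    for images, labels in test_loader:
        embeddings = model(images)          # one embedding vector per sample
        loss = loss_fn(embeddings, labels)  # assumed call signature
        total_loss += loss.item()
        n_batches += 1
print("mean test triplet loss: %.4f" % (total_loss / n_batches))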