Example #1
    def fit_model(self):
        """
        Fits the model using the AdamW optimizer, stochastic weight averaging (SWA), and a cosine annealing (warm restarts) learning rate schedule.
        """
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, 100, 2
        )

        self.swa_model = AveragedModel(self.model)
        swa_start = 750
        swa_scheduler = SWALR(
            optimizer, swa_lr=0.001, anneal_epochs=10, anneal_strategy="cos"
        )

        self.model.train()
        self.swa_model.train()
        for epoch in range(1000):
            optimizer.zero_grad()
            output = self.model(self.x)

            loss = -output.log_prob(self.y.view(-1, 1)).sum()

            loss.backward()
            optimizer.step()

            if epoch > swa_start:
                self.swa_model.update_parameters(self.model)
                swa_scheduler.step()
            else:
                scheduler.step()

            if epoch % 10 == 0:
                print(f"Epoch {epoch} complete. Loss: {loss}")
Example #2
def train(num_epochs, model, data_loader, val_loader, val_every, device, file_name):
    learning_rate = 0.0001
    from torch.optim.swa_utils import AveragedModel, SWALR
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from segmentation_models_pytorch.losses import SoftCrossEntropyLoss, JaccardLoss
    from adamp import AdamP
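    # Note (assumption): Lookahead, add_hist, label_accuracy_score, validation, save_model
    # and saved_dir are not defined in this snippet and are assumed to come from the
    # surrounding project; in particular, Lookahead(optimizer, la_alpha=...) is a
    # project-local wrapper rather than a torch.optim class.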

    criterion = [SoftCrossEntropyLoss(smooth_factor=0.1), JaccardLoss('multiclass', classes=12)]
    optimizer = AdamP(params=model.parameters(), lr=learning_rate, weight_decay=1e-6)
    swa_scheduler = SWALR(optimizer, swa_lr=learning_rate)
    swa_model = AveragedModel(model)
    look = Lookahead(optimizer, la_alpha=0.5)

    print('Start training..')
    best_miou = 0
    for epoch in range(num_epochs):
        hist = np.zeros((12, 12))
        model.train()
        for step, (images, masks, _) in enumerate(data_loader):
            loss = 0
            images = torch.stack(images)  # (batch, channel, height, width)
            masks = torch.stack(masks).long()  # (batch, channel, height, width)

            # move tensors to the device for GPU computation
            images, masks = images.to(device), masks.to(device)

            # inference
            outputs = model(images)
            # loss computation (soft cross entropy + Jaccard)
            for i in criterion:
                loss += i(outputs, masks)

            look.zero_grad()
            loss.backward()
            look.step()

            outputs = torch.argmax(outputs.squeeze(), dim=1).detach().cpu().numpy()
            hist = add_hist(hist, masks.detach().cpu().numpy(), outputs, n_class=12)
            acc, acc_cls, mIoU, fwavacc = label_accuracy_score(hist)
            # print loss and mIoU at a fixed step interval
            if (step + 1) % 25 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, mIoU: {:.4f}'.format(
                    epoch + 1, num_epochs, step + 1, len(data_loader), loss.item(), mIoU))

        # at each validation interval, report the loss and save the best model
        if (epoch + 1) % val_every == 0:
            avrg_loss, val_miou = validation(epoch + 1, model, val_loader, criterion, device)
            if val_miou > best_miou:
                print('Best performance at epoch: {}'.format(epoch + 1))
                print('Save model in', saved_dir)
                best_miou = val_miou
                save_model(model, file_name = file_name)

        if epoch > 3:
            swa_model.update_parameters(model)
            swa_scheduler.step()
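
The function above keeps swa_model up to date but never refreshes its BatchNorm statistics or saves it. A hedged follow-up sketch, reusing the torch.stack collate convention of the loop above (the generator and the save call are assumptions, not part of the original):

    # Hedged sketch: recompute BatchNorm running statistics for the averaged model,
    # feeding it batches exactly as the training loop does, then save it.
    torch.optim.swa_utils.update_bn(
        (torch.stack(images) for images, _, _ in data_loader),
        swa_model,
        device=device,
    )
    save_model(swa_model, file_name="swa_" + file_name)  # hypothetical reuse of the project helper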
Example #3
def train_model(indep_vars, dep_var, verbose=True):
    """
    Trains MDNVol network. Uses AdamW optimizer with cosine annealing learning rate schedule.
    Outputs the model averaged over the last 25% of training epochs.

    indep_vars: n x m torch tensor containing independent variables
        n = number of data points
        m = number of input variables
    dep_var: n x 1 torch tensor containing single dependent variable
        n = number of data points
        1 = single output variable
    """
    model = MDN(indep_vars.shape[1], 1, 250, 5)
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, 100, 2)

    swa_model = AveragedModel(model)
    swa_start = 750
    swa_scheduler = SWALR(optimizer,
                          swa_lr=0.001,
                          anneal_epochs=10,
                          anneal_strategy="cos")

    model.train()
    swa_model.train()
    for epoch in range(1000):
        optimizer.zero_grad()
        output = model(indep_vars)

        loss = -output.log_prob(dep_var).sum()

        loss.backward()
        optimizer.step()

        if epoch > swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()

        if epoch % 10 == 0:
            if verbose:
                print(f"Epoch {epoch} complete. Loss: {loss}")

    swa_model.eval()
    return swa_model
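
A hedged usage sketch for train_model above (the shapes follow the docstring; the random data is purely illustrative and not from the original):

# Hypothetical call, assuming MDN is importable in this scope.
indep_vars = torch.randn(500, 8)   # n = 500 data points, m = 8 input variables
dep_var = torch.randn(500, 1)      # single dependent variable
swa_mdn = train_model(indep_vars, dep_var, verbose=False)
with torch.no_grad():
    dist = swa_mdn(indep_vars)                 # AveragedModel forwards to the wrapped MDN
    print(-dist.log_prob(dep_var).sum())       # same negative log-likelihood as the training loss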
Example #4
class SWALRRunner(ClassificationRunner):
    def __init__(self, *args, **kwargs):
        super(SWALRRunner, self).__init__(*args, **kwargs)
        self.swa_model = AveragedModel(self.model)
        self.swa_scheduler = SWALR(self.optimizer, swa_lr=0.05)
        self.swa_start = 5

    def update_scheduler(self, epoch: int) -> None:
        if epoch > self.swa_start:
            self.swa_model.update_parameters(self.model)
            self.swa_scheduler.step()

        else:
            super(SWALRRunner, self).update_scheduler(epoch)

    def train_end(self, outputs):
        update_bn(self.loaders["train"], self.swa_model)
        return super(SWALRRunner, self).train_end(outputs)
Example #5
class Learner:
    def __init__(self, cfg_dir: str, data_loader: DataLoader, model,
                 labels_definition):
        self.cfg = get_conf(cfg_dir)
        self._labels_definition = labels_definition
        #TODO
        self.logger = self.init_logger(self.cfg.logger)
        #self.dataset = CustomDataset(**self.cfg.dataset)
        self.data = data_loader
        #self.val_dataset = CustomDatasetVal(**self.cfg.val_dataset)
        #self.val_data = DataLoader(self.val_dataset, **self.cfg.dataloader)
        # self.logger.log_parameters({"tr_len": len(self.dataset),
        #                             "val_len": len(self.val_dataset)})
        self.model = model
        #self.model._resnet.conv1.apply(init_weights_normal)
        self.device = self.cfg.train_params.device
        self.model = self.model.to(device=self.device)
        if self.cfg.train_params.optimizer.lower() == "adam":
            self.optimizer = optim.Adam(self.model.parameters(),
                                        **self.cfg.adam)
        elif self.cfg.train_params.optimizer.lower() == "rmsprop":
            self.optimizer = optim.RMSprop(self.model.parameters(),
                                           **self.cfg.rmsprop)
        else:
            raise ValueError(
                f"Unknown optimizer {self.cfg.train_params.optimizer}")

        self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100)
        self.criterion = nn.BCELoss()

        if self.cfg.logger.resume:
            # load checkpoint
            print("Loading checkpoint")
            save_dir = self.cfg.directory.load
            checkpoint = load_checkpoint(save_dir, self.device)
            self.model.load_state_dict(checkpoint["model"])
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            self.epoch = checkpoint["epoch"]
            self.e_loss = checkpoint["e_loss"]
            self.best = checkpoint["best"]
            print(
                f"{datetime.now():%Y-%m-%d %H:%M:%S} "
                f"Loading checkpoint was successful, start from epoch {self.epoch}"
                f" and loss {self.best}")
        else:
            self.epoch = 1
            self.best = np.inf
            self.e_loss = []

        # initialize the early_stopping object
        self.early_stopping = EarlyStopping(
            patience=self.cfg.train_params.patience,
            verbose=True,
            delta=self.cfg.train_params.early_stopping_delta,
        )

        # stochastic weight averaging
        self.swa_model = AveragedModel(self.model)
        self.swa_scheduler = SWALR(self.optimizer, **self.cfg.SWA)
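        # Note (assumption): SWALR(self.optimizer, **self.cfg.SWA) expects SWALR keyword
        # arguments, so cfg.SWA is presumably something like
        # {"swa_lr": 0.05, "anneal_epochs": 10, "anneal_strategy": "cos"}.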

    def train(self, task: VisionTask):
        task.go_to_gpu(self.device)

        visualize_idx = np.random.randint(0, len(self.data), 50)

        while self.epoch <= self.cfg.train_params.epochs:
            running_loss = []
            self.model.train()

            for internel_iter, (images, gt_boxes, gt_labels, ego_labels,
                                counts, img_indexs,
                                wh) in enumerate(self.data):
                self.optimizer.zero_grad()

                # fl = task.get_flat_label(gt_labels)

                m = nn.Sigmoid()
                y = task.get_flat_label(gt_labels)
                x = images

                # move data to device
                x = x.to(device=self.device)
                y = y.to(device=self.device)

                # forward, backward
                encoded_vector = self.model(x)
                out = task.decode(encoded_vector)
                loss = self.criterion(m(out), y)
                loss.backward()
                # check grad norm for debugging
                grad_norm = check_grad_norm(self.model)
                # update
                self.optimizer.step()

                running_loss.append(loss.item())

                #print("Loss:", loss.item())
                #print("grad_norm", grad_norm)

                self.logger.log_metrics(
                    {
                        #"epoch": self.epoch,
                        "batch": internel_iter,
                        "loss": loss.item(),
                        "GradNorm": grad_norm,
                    },
                    epoch=self.epoch)

                #validation
                if internel_iter % 1000 == 0 and self.epoch % 5 == 0:
                    print("Internel iter: ", internel_iter)
                    out = m(out[-1])
                    definitions = []
                    l = task.boundary[1] - task.boundary[0]
                    n_boxes = len(gt_boxes[-1][-1])
                    print("Number of Boxes:", n_boxes)
                    name = "img_" + str(self.epoch) + "_" + str(
                        internel_iter / 1000)
                    for i in range(n_boxes):
                        prediction = out[i * l + 1 + i:i * l + l + 1 + i]
                        prediction = prediction.argmax()
                        definitions.append(name + ": " +
                                           self._labels_definition[
                                               task.get_name()][prediction])

                    print("list", definitions)
                    sz = wh[0][0].item()
                    img = torch.zeros([3, sz, sz])
                    img[0] = images[-1][self.cfg.dataloader.seq_len - 1]
                    img[1] = images[-1][2 * self.cfg.dataloader.seq_len - 1]
                    img[2] = images[-1][3 * self.cfg.dataloader.seq_len - 1]
                    self.logger.log_image(img,
                                          name=name,
                                          image_channels='first')

                #if internel_iter < 10:
                #    sz = wh[0][0].item()
                #    img = torch.zeros([3, sz, sz])
                #    print(img.shape)
                #    print(images.shape)
                #    img[0] = images[-1][self.cfg.dataloader.seq_len -1]
                #    img[1] = images[-1][2*self.cfg.dataloader.seq_len - 1]
                #    img[2] = images[-1][3*self.cfg.dataloader.seq_len - 1]
                #    self.log_image_with_text_on_it(img, gt_labels[-1][-1], task)
                #self.logger.log_image(img, name="v", image_channels='first')

            #bar.close()

            # Visualize
            # self.predict_visualize(index_list=visualize_idx, task=task)

            if self.epoch > self.cfg.train_params.swa_start:
                self.swa_model.update_parameters(self.model)
                self.swa_scheduler.step()
            else:
                self.lr_scheduler.step()

            # validate on val set
            # val_loss, t = self.validate()
            # t /= len(self.val_dataset)

            # average loss for an epoch
            self.e_loss.append(np.mean(running_loss))  # epoch loss
            # print(
            #     f"{datetime.now():%Y-%m-%d %H:%M:%S} Epoch {self.epoch} summary: train Loss: {self.e_loss[-1]:.2f} \t| Val loss: {val_loss:.2f}"
            #     f"\t| time: {t:.3f} seconds"
            # )

            self.logger.log_metrics({
                "epoch": self.epoch,
                "epoch_loss": self.e_loss[-1],
            })

            # early_stopping needs the validation loss to check if it has decreased,
            # and if it has, it will make a checkpoint of the current model
            #self.early_stopping(val_loss, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                self.save()
                break

            if self.epoch % self.cfg.train_params.save_every == 0:
                self.save()

            gc.collect()
            print("Task: " + task.get_name() + " epoch[" + str(self.epoch) +
                  "] finished.")
            self.epoch += 1

        # Update bn statistics for the swa_model at the end
        #if self.epoch >= self.cfg.train_params.swa_start:
        #            torch.optim.swa_utils.update_bn(self.data.to(self.device), self.swa_model)
        # self.save(name=self.cfg.directory.model_name + "-final" + str(self.epoch) + "-swa")

        # macs, params = op_counter(self.model, sample=x)
        # print(macs, params)
        # self.logger.log_metrics({"GFLOPS": macs[:-1], "#Params": params[:-1], "task name": task.get_name(), "total_loss": self.e_loss[-1]})
        print("Training Finished!")
        return loss

    def train_multi(self, primary_task, auxiliary_tasks):

        # 1- move all tasks to the gpu
        for auxiliary_task in auxiliary_tasks:
            auxiliary_task.go_to_gpu(self.device)
        primary_task.go_to_gpu(self.device)

        activation_function = nn.Sigmoid()

        while self.epoch <= self.cfg.train_params.epochs:
            running_loss = []
            self.model.train()

            for internel_iter, (images, gt_boxes, gt_labels, ego_labels,
                                counts, img_indexs,
                                wh) in enumerate(self.data):
                self.optimizer.zero_grad()

                x = images
                x = x.to(device=self.device)
                encoded_vector = self.model(x)

                total_loss = None
                # for auxiliary tasks
                for auxiliary_task in auxiliary_tasks:
                    y = auxiliary_task.get_flat_label(gt_labels)
                    # move data to device
                    y = y.to(device=self.device)
                    # forward
                    out = auxiliary_task.decode(encoded_vector)
                    auxiliary_loss = self.criterion(activation_function(out),
                                                    y)
                    if total_loss is None:
                        total_loss = auxiliary_loss
                    else:
                        total_loss += auxiliary_loss

                # for primary task
                y = primary_task.get_flat_label(gt_labels)
                # move data to device
                y = y.to(device=self.device)
                # forward
                out = primary_task.decode(encoded_vector)
                primary_loss = self.criterion(activation_function(out), y)
                total_loss += primary_loss

                total_loss.backward()
                # check grad norm for debugging
                grad_norm = check_grad_norm(self.model)
                # update
                self.optimizer.step()

                running_loss.append(primary_loss.item())

                self.logger.log_metrics(
                    {
                        # "epoch": self.epoch,
                        "batch": internel_iter,
                        "primary_loss": primary_loss.item(),
                        "GradNorm": grad_norm,
                    },
                    epoch=self.epoch)

                # validation
                if internel_iter % 1000 == 0 and self.epoch % 5 == 0:
                    print("Internel iter: ", internel_iter)
                    out = activation_function(out[-1])
                    definitions = []
                    l = primary_task.boundary[1] - primary_task.boundary[0]
                    n_boxes = len(gt_boxes[-1][-1])
                    print("Number of Boxes:", n_boxes)
                    name = "img_" + str(self.epoch) + "_" + str(
                        internel_iter / 1000)
                    for i in range(n_boxes):
                        prediction = out[i * l + 1 + i:i * l + l + 1 + i]
                        prediction = prediction.argmax()
                        definitions.append(
                            name + ": " + self._labels_definition[
                                primary_task.get_name()][prediction])

                    print("list", definitions)
                    sz = wh[0][0].item()
                    img = torch.zeros([3, sz, sz])
                    img[0] = images[-1][self.cfg.dataloader.seq_len - 1]
                    img[1] = images[-1][2 * self.cfg.dataloader.seq_len - 1]
                    img[2] = images[-1][3 * self.cfg.dataloader.seq_len - 1]

                    img_with_text = draw_text(img, definitions)
                    self.logger.log_image(img_with_text,
                                          name=name,
                                          image_channels='first')

            # Visualize
            # self.predict_visualize(index_list=visualize_idx, task=task)

            if self.epoch > self.cfg.train_params.swa_start:
                self.swa_model.update_parameters(self.model)
                self.swa_scheduler.step()
            else:
                self.lr_scheduler.step()

            # validate on val set
            # val_loss, t = self.validate()
            # t /= len(self.val_dataset)

            # average loss for an epoch
            self.e_loss.append(np.mean(running_loss))  # epoch loss
            # print(
            #     f"{datetime.now():%Y-%m-%d %H:%M:%S} Epoch {self.epoch} summary: train Loss: {self.e_loss[-1]:.2f} \t| Val loss: {val_loss:.2f}"
            #     f"\t| time: {t:.3f} seconds"
            # )

            self.logger.log_metrics({
                "epoch": self.epoch,
                "epoch_loss": self.e_loss[-1],
            })

            # early_stopping needs the validation loss to check if it has decreased,
            # and if it has, it will make a checkpoint of the current model
            # self.early_stopping(val_loss, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                self.save()
                break

            if self.epoch % self.cfg.train_params.save_every == 0:
                self.save()

            gc.collect()
            print("Task: " + primary_task.get_name() + " epoch[" +
                  str(self.epoch) + "] finished.")
            self.epoch += 1

        # Update bn statistics for the swa_model at the end
        # if self.epoch >= self.cfg.train_params.swa_start:
        #            torch.optim.swa_utils.update_bn(self.data.to(self.device), self.swa_model)
        # self.save(name=self.cfg.directory.model_name + "-final" + str(self.epoch) + "-swa")

        # macs, params = op_counter(self.model, sample=x)
        # print(macs, params)
        # self.logger.log_metrics({"GFLOPS": macs[:-1], "#Params": params[:-1], "task name": task.get_name(), "total_loss": self.e_loss[-1]})
        print("Training Finished!")
        return primary_loss

    def predict_visualize(self, index_list, task):
        print("===================================================")
        for i in index_list:
            images, gt_boxes, gt_labels, ego_labels, counts, img_indexs, wh = self.data.dataset.__getitem__(
                i)
            sz = img_indexs[0]

            y = task.get_flat_label(gt_labels)
            x = images

            # move data to device
            x = x.to(device=self.device)
            y = y.to(device=self.device)

            encoded_vector = self.model(x)
            out = task.decode(encoded_vector)
            self.log_image_with_text(img_tensor=images,
                                     out_vector=out,
                                     index=i,
                                     task=task)
        print("===================================================")

    def log_image_with_text(self, img_tensor, out_vector, index, task):
        definitions = []
        label_len = task.boundary[1] - task.boundary[0]
        name = "img_" + str(index)
        i = 0
        while True:
            finished = out_vector[i]
            if finished == True:
                break
            i += 1

            l = out_vector[i, label_len]
            i += label_len
            if len(np.nonzero(l)) > 0:
                definition_idx = np.nonzero(l)[0][0]
                definitions.append(
                    name + ": " +
                    self._labels_definition[task.get_name()][definition_idx])

        print(definitions)
        self.logger.log_image(img_tensor, name=name, image_channels='first')

    def log_image_with_text_on_it(self, img_tensor, labels, task):
        definitions = []
        box_count = len(labels)
        for j in range(min(box_count, VisionTask._max_box_count)):
            l = labels[j]  # len(l) = 149
            l = l[task.boundary[0]:task.boundary[1]]
            if len(np.nonzero(l)) > 0:
                definition_idx = np.nonzero(l)[0][0]
                definitions.append(
                    self._labels_definition[task.get_name()][definition_idx])

        img = draw_text(img_tensor, definitions)
        print(definitions)
        # print(images.shape)
        self.logger.log_image(img_tensor, name="v", image_channels='first')

    # @timeit
    # @torch.no_grad()
    # def validate(self):
    #
    #     self.model.eval()
    #
    #     running_loss = []
    #
    #     for idx, (x, y) in tqdm(enumerate(self.val_data), desc="Validation"):
    #         # move data to device
    #         x = x.to(device=self.device)
    #         y = y.to(device=self.device)
    #
    #         # forward, backward
    #         if self.epoch > self.cfg.train_params.swa_start:
    #             # Update bn statistics for the swa_model
    #             torch.optim.swa_utils.update_bn(self.data, self.swa_model)
    #             out = self.swa_model(x)
    #         else:
    #             out = self.model(x)
    #
    #         loss = self.criterion(out, y)
    #         running_loss.append(loss.item())
    #
    #     # average loss
    #     loss = np.mean(running_loss)
    #
    #     return loss

    def init_logger(self, cfg):
        logger = None
        # Check to see if there is a key in environment:
        EXPERIMENT_KEY = cfg.experiment_key

        # First, let's see if we continue or start fresh:
        CONTINUE_RUN = cfg.resume
        if (EXPERIMENT_KEY is not None):
            # There is one, but the experiment might not exist yet:
            api = comet_ml.API()  # Assumes API key is set in config/env
            try:
                api_experiment = api.get_experiment_by_id(EXPERIMENT_KEY)
            except Exception:
                api_experiment = None
            if api_experiment is not None:
                CONTINUE_RUN = True
                # We can get the last details logged here, if logged:
                # step = int(api_experiment.get_parameters_summary("batch")["valueCurrent"])
                # epoch = int(api_experiment.get_parameters_summary("epochs")["valueCurrent"])

        if CONTINUE_RUN:
            # 1. Recreate the state of ML system before creating experiment
            # otherwise it could try to log params, graph, etc. again
            # ...
            # 2. Setup the existing experiment to carry on:
            logger = comet_ml.ExistingExperiment(
                previous_experiment=EXPERIMENT_KEY,
                log_env_details=True,  # to continue env logging
                log_env_gpu=True,  # to continue GPU logging
                log_env_cpu=True,  # to continue CPU logging
                auto_histogram_weight_logging=True,
                auto_histogram_gradient_logging=True,
                auto_histogram_activation_logging=True)
            # Retrieved from above APIExperiment
            # self.logger.set_epoch(epoch)

        else:
            # 1. Create the experiment first
            #    This will use the COMET_EXPERIMENT_KEY if defined in env.
            #    Otherwise, you could manually set it here. If you don't
            #    set COMET_EXPERIMENT_KEY, the experiment will get a
            #    random key!
            logger = comet_ml.Experiment(
                disabled=cfg.disabled,
                project_name=cfg.project,
                auto_histogram_weight_logging=True,
                auto_histogram_gradient_logging=True,
                auto_histogram_activation_logging=True)
            logger.add_tags(cfg.tags.split())
            logger.log_parameters(self.cfg)

        return logger

    def save(self, name=None):
        checkpoint = {
            "epoch": self.epoch,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best": self.best,
            "e_loss": self.e_loss
        }

        if name is None and self.epoch >= self.cfg.train_params.swa_start:
            save_name = self.cfg.directory.model_name + str(
                self.epoch) + "-swa"
            checkpoint['model-swa'] = self.swa_model.state_dict()

        elif name is None:
            save_name = self.cfg.directory.model_name + str(self.epoch)

        else:
            save_name = name

        if self.e_loss[-1] < self.best:
            self.best = self.e_loss[-1]
            checkpoint["best"] = self.best
            save_checkpoint(checkpoint, True, self.cfg.directory.save,
                            save_name)
        else:
            save_checkpoint(checkpoint, False, self.cfg.directory.save,
                            save_name)
Example #6
def pseudo_labeling(num_epochs, model, data_loader, val_loader,
                    unlabeled_loader, device, val_every, file_name):
    # Instead of using current epoch we use a "step" variable to calculate alpha_weight
    # This helps the model converge faster
    from torch.optim.swa_utils import AveragedModel, SWALR
    from segmentation_models_pytorch.losses import SoftCrossEntropyLoss, JaccardLoss
    from adamp import AdamP

    criterion = [
        SoftCrossEntropyLoss(smooth_factor=0.1),
        JaccardLoss('multiclass', classes=12)
    ]
    optimizer = AdamP(params=model.parameters(), lr=0.0001, weight_decay=1e-6)
    swa_scheduler = SWALR(optimizer, swa_lr=0.0001)
    swa_model = AveragedModel(model)
    optimizer = Lookahead(optimizer, la_alpha=0.5)

    step = 100
    size = 256
    best_mIoU = 0
    model.train()
    print('Start Pseudo-Labeling..')
    for epoch in range(num_epochs):
        hist = np.zeros((12, 12))
        for batch_idx, (imgs, image_infos) in enumerate(unlabeled_loader):

            # Forward Pass to get the pseudo labels
            # --------------------------------------------- pass the unlabeled (test) batch through the model
            model.eval()
            outs = model(torch.stack(imgs).to(device))
            oms = torch.argmax(outs.squeeze(), dim=1).detach().cpu().numpy()
            oms = torch.Tensor(oms)
            oms = oms.long()
            oms = oms.to(device)

            # --------------------------------------------- training

            model.train()
            # Now calculate the unlabeled loss using the pseudo label
            imgs = torch.stack(imgs)
            imgs = imgs.to(device)
            # preds_array = preds_array.to(device)

            output = model(imgs)
            loss = 0
            for each in criterion:
                loss += each(output, oms)

            unlabeled_loss = alpha_weight(step) * loss

            # Backpropogate
            optimizer.zero_grad()
            unlabeled_loss.backward()
            optimizer.step()
            output = torch.argmax(output.squeeze(),
                                  dim=1).detach().cpu().numpy()
            hist = add_hist(hist,
                            oms.detach().cpu().numpy(),
                            output,
                            n_class=12)

            if (batch_idx + 1) % 25 == 0:
                acc, acc_cls, mIoU, fwavacc = label_accuracy_score(hist)
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, mIoU:{:.4f}'.
                      format(epoch + 1, num_epochs, batch_idx + 1,
                             len(unlabeled_loader), unlabeled_loss.item(),
                             mIoU))
            # For every 50 batches, train one epoch on the labeled data
            if batch_idx % 50 == 0:

                # Normal training procedure
                for batch_idx, (images, masks, _) in enumerate(data_loader):
                    labeled_loss = 0
                    images = torch.stack(images)
                    # (batch, channel, height, width)
                    masks = torch.stack(masks).long()

                    # move tensors to the device for GPU computation
                    images, masks = images.to(device), masks.to(device)

                    output = model(images)

                    for each in criterion:
                        labeled_loss += each(output, masks)

                    optimizer.zero_grad()
                    labeled_loss.backward()
                    optimizer.step()

                # Now we increment step by 1
                step += 1

        if (epoch + 1) % val_every == 0:
            avrg_loss, val_mIoU = validation(epoch + 1, model, val_loader,
                                             criterion, device)
            if val_mIoU > best_mIoU:
                print('Best performance at epoch: {}'.format(epoch + 1))
                print('Save model in', saved_dir)
                best_mIoU = val_mIoU
                save_model(model, file_name=file_name)

        model.train()

        if epoch > 3:
            swa_model.update_parameters(model)
            swa_scheduler.step()
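
alpha_weight is called above but not defined in this example. A hedged sketch of the ramp-up commonly used to weight the pseudo-label loss (T1, T2 and af are assumed values, not taken from the original):

def alpha_weight(step, T1=100, T2=600, af=3.0):
    # Hypothetical schedule: no pseudo-label loss before step T1, a linear ramp
    # between T1 and T2, and a constant weight af afterwards.
    if step < T1:
        return 0.0
    if step < T2:
        return (step - T1) / (T2 - T1) * af
    return af

Example #7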
    def fit(
        self,
        train_objectives: Iterable[Tuple[DataLoader, nn.Module]],
        evaluator: SentenceEvaluator = None,
        epochs: int = 1,
        steps_per_epoch=None,
        scheduler: str = 'WarmupLinear',
        warmup_steps: int = 10000,
        optimizer_class: Type[Optimizer] = transformers.AdamW,
        optimizer_params: Dict[str, object] = {
            'lr': 2e-5,
            'eps': 1e-6,
            'correct_bias': False
        },
        weight_decay: float = 0.01,
        evaluation_steps: int = 0,
        output_path: str = None,
        save_best_model: bool = True,
        max_grad_norm: float = 1,
        use_amp: bool = False,
        callback: Callable[[float, int, int], None] = None,
        show_progress_bar: bool = True,
        log_every: int = 100,
        wandb_project_name: str = None,
        wandb_config: Dict[str, object] = {},
        use_swa: bool = False,
        swa_epochs_start: int = 5,
        swa_anneal_epochs: int = 10,
        swa_lr: float = 0.05,
    ):
        """
        Train the model with the given training objective
        Each training objective is sampled in turn for one batch.
        We sample only as many batches from each objective as there are in the smallest one
        to make sure of equal training with each dataset.

        :param train_objectives: Tuples of (DataLoader, LossFunction). Pass more than one for multi-task learning
        :param evaluator: An evaluator (sentence_transformers.evaluation) evaluates the model performance during training on held-out dev data. It is used to determine the best model that is saved to disc.
        :param epochs: Number of epochs for training
        :param steps_per_epoch: Number of training steps per epoch. If set to None (default), one epoch equals the size of the smallest DataLoader from train_objectives.
        :param scheduler: Learning rate scheduler. Available schedulers: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
        :param warmup_steps: Behavior depends on the scheduler. For WarmupLinear (default), the learning rate is increased from 0 up to the maximal learning rate. After these many training steps, the learning rate is decreased linearly back to zero.
        :param optimizer_class: Optimizer
        :param optimizer_params: Optimizer parameters
        :param weight_decay: Weight decay for model parameters
        :param evaluation_steps: If > 0, evaluate the model using evaluator after each number of training steps
        :param output_path: Storage path for the model and evaluation files
        :param save_best_model: If true, the best model (according to evaluator) is stored at output_path
        :param max_grad_norm: Used for gradient norm clipping.
        :param use_amp: Use Automatic Mixed Precision (AMP). Only for Pytorch >= 1.6.0
        :param callback: Callback function that is invoked after each evaluation.
                It must accept the following three parameters in this order:
                `score`, `epoch`, `steps`
        :param show_progress_bar: If True, output a tqdm progress bar
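        :param log_every: If a Weights & Biases run is active, log the current loss and learning rate every this many training steps
        :param wandb_project_name: If set and wandb is installed, metrics are logged to this Weights & Biases project
        :param wandb_config: Extra keyword arguments passed to wandb.init
        :param use_swa: If True, maintain a stochastic weight averaging (SWA) copy of the model and return it at the end of training
        :param swa_epochs_start: Epoch after which the SWA model starts being updated
        :param swa_anneal_epochs: anneal_epochs passed to the SWALR scheduler
        :param swa_lr: Learning rate used by the SWALR scheduler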
        """

        if use_amp:
            from torch.cuda.amp import autocast
            scaler = torch.cuda.amp.GradScaler()

        self.to(self._target_device)

        if output_path is not None:
            os.makedirs(output_path, exist_ok=True)

        dataloaders = [dataloader for dataloader, _ in train_objectives]

        # Use smart batching
        for dataloader in dataloaders:
            dataloader.collate_fn = self.smart_batching_collate

        loss_models = [loss for _, loss in train_objectives]
        for loss_model in loss_models:
            loss_model.to(self._target_device)

        self.best_score = -9999999

        if steps_per_epoch is None or steps_per_epoch == 0:
            steps_per_epoch = min(
                [len(dataloader) for dataloader in dataloaders])

        num_train_steps = int(steps_per_epoch * epochs)

        # Prepare logger
        if wandb_available and wandb_project_name:
            if not wandb.setup().settings.sweep_id:
                config = {
                    'epochs': epochs,
                    'steps_per_epoch': steps_per_epoch,
                    'scheduler': scheduler,
                    'warmup_steps': warmup_steps,
                    'weight_decay': weight_decay,
                    'evaluation_steps': evaluation_steps,
                    'output_path': output_path,
                    'save_best_model': save_best_model,
                    'max_grad_norm': max_grad_norm,
                    'use_amp': use_amp,
                }
                wandb.init(project=wandb_project_name,
                           config=config,
                           **wandb_config)
            wandb.watch(self)

        # SWA
        if use_swa:
            swa_model = AveragedModel(self)

        # Prepare optimizers
        optimizers = []
        schedulers = []
        for loss_model in loss_models:
            param_optimizer = list(loss_model.named_parameters())

            no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                weight_decay
            }, {
                'params': [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]

            optimizer = optimizer_class(optimizer_grouped_parameters,
                                        **optimizer_params)
            scheduler_obj = self._get_scheduler(optimizer,
                                                scheduler=scheduler,
                                                warmup_steps=warmup_steps,
                                                t_total=num_train_steps)

            optimizers.append(optimizer)
            schedulers.append(scheduler_obj)
        if use_swa:
            swa_scheduler = SWALR(optimizers[0],
                                  swa_lr=swa_lr,
                                  anneal_epochs=swa_anneal_epochs,
                                  anneal_strategy='linear')

        global_step = 0
        data_iterators = [iter(dataloader) for dataloader in dataloaders]

        num_train_objectives = len(train_objectives)

        skip_scheduler = False
        for epoch in trange(epochs,
                            desc="Epoch",
                            disable=not show_progress_bar):
            training_steps = 0

            for loss_model in loss_models:
                loss_model.zero_grad()
                loss_model.train()

            for _ in trange(steps_per_epoch,
                            desc="Iteration",
                            smoothing=0.05,
                            disable=not show_progress_bar):
                for train_idx in range(num_train_objectives):
                    loss_model = loss_models[train_idx]
                    optimizer = optimizers[train_idx]
                    scheduler = schedulers[train_idx]
                    data_iterator = data_iterators[train_idx]

                    try:
                        data = next(data_iterator)
                    except StopIteration:
                        data_iterator = iter(dataloaders[train_idx])
                        data_iterators[train_idx] = data_iterator
                        data = next(data_iterator)

                    features, labels = data

                    if use_amp:
                        with autocast():
                            loss_value = loss_model(features, labels)

                        scale_before_step = scaler.get_scale()
                        scaler.scale(loss_value).backward()
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(),
                                                       max_grad_norm)
                        scaler.step(optimizer)
                        scaler.update()

                        skip_scheduler = scaler.get_scale(
                        ) != scale_before_step
                    else:
                        loss_value = loss_model(features, labels)
                        loss_value.backward()
                        torch.nn.utils.clip_grad_norm_(loss_model.parameters(),
                                                       max_grad_norm)
                        optimizer.step()

                    optimizer.zero_grad()

                    # if wandb init is called
                    if wandb_available and wandb.run is not None and (
                            training_steps + 1) % log_every == 0:
                        wandb.log(
                            {
                                loss_model.__class__.__name__:
                                loss_value.item(),
                                "lr": scheduler.get_last_lr()[0],
                            },
                            step=global_step)

                    if not skip_scheduler:
                        scheduler.step()

                training_steps += 1
                global_step += 1

                if evaluation_steps > 0 and training_steps % evaluation_steps == 0:
                    self._eval_during_training(evaluator, output_path,
                                               save_best_model, epoch,
                                               training_steps, global_step,
                                               callback)
                    for loss_model in loss_models:
                        loss_model.zero_grad()
                        loss_model.train()

            if use_swa and epoch > swa_epochs_start:
                swa_model.update_parameters(self)
                swa_scheduler.step()

            self._eval_during_training(evaluator, output_path, save_best_model,
                                       epoch, -1, global_step, callback)
        if use_swa:
            return swa_model
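
A hedged usage sketch of fit() with SWA enabled (model, train_dataloader and train_loss are placeholders in the sentence-transformers style and are not defined in this example):

# Hypothetical call pattern, not part of the original.
swa_model = model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=10,
    warmup_steps=1000,
    use_swa=True,
    swa_epochs_start=5,
    swa_anneal_epochs=3,
    swa_lr=2e-5,
)
# swa_model is a torch.optim.swa_utils.AveragedModel wrapping the trained model.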
Example #8
class HM:
    def __init__(self):

        if args.train is not None:
            self.train_tuple = get_tuple(args.train,
                                         bs=args.batch_size,
                                         shuffle=True,
                                         drop_last=False)

        if args.valid is not None:
            valid_bsize = 2048 if args.multiGPU else 50
            self.valid_tuple = get_tuple(args.valid,
                                         bs=valid_bsize,
                                         shuffle=False,
                                         drop_last=False)
        else:
            self.valid_tuple = None

        # Select Model, X is default
        if args.model == "X":
            self.model = ModelX(args)
        elif args.model == "V":
            self.model = ModelV(args)
        elif args.model == "U":
            self.model = ModelU(args)
        elif args.model == "D":
            self.model = ModelD(args)
        elif args.model == 'O':
            self.model = ModelO(args)
        else:
            print(args.model, " is not implemented.")

        # Load pre-trained weights from paths
        if args.loadpre is not None:
            self.model.load(args.loadpre)

        # GPU options
        if args.multiGPU:
            self.model.lxrt_encoder.multi_gpu()

        self.model = self.model.cuda()

        # Losses and optimizer
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.nllloss = nn.NLLLoss()

        if args.train is not None:
            batch_per_epoch = len(self.train_tuple.loader)
            self.t_total = int(batch_per_epoch * args.epochs // args.acc)
            print("Total Iters: %d" % self.t_total)

        def is_backbone(n):
            if "encoder" in n:
                return True
            elif "embeddings" in n:
                return True
            elif "pooler" in n:
                return True
            print("F: ", n)
            return False

        no_decay = ['bias', 'LayerNorm.weight']

        params = list(self.model.named_parameters())
        if args.reg:
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in params if is_backbone(n)],
                    "lr": args.lr
                },
                {
                    "params": [p for n, p in params if not is_backbone(n)],
                    "lr": args.lr * 500
                },
            ]

            for n, p in self.model.named_parameters():
                print(n)

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)
        else:
            optimizer_grouped_parameters = [{
                'params':
                [p for n, p in params if not any(nd in n for nd in no_decay)],
                'weight_decay':
                args.wd
            }, {
                'params':
                [p for n, p in params if any(nd in n for nd in no_decay)],
                'weight_decay':
                0.0
            }]

            self.optim = AdamW(optimizer_grouped_parameters, lr=args.lr)

        if args.train is not None:
            self.scheduler = get_linear_schedule_with_warmup(
                self.optim, self.t_total * 0.1, self.t_total)

        self.output = args.output
        os.makedirs(self.output, exist_ok=True)

        # SWA Method:
        if args.contrib:
            self.optim = SWA(self.optim,
                             swa_start=self.t_total * 0.75,
                             swa_freq=5,
                             swa_lr=args.lr)

        if args.swa:
            self.swa_model = AveragedModel(self.model)
            self.swa_start = self.t_total * 0.75
            self.swa_scheduler = SWALR(self.optim, swa_lr=args.lr)
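            # Note: self.swa_start is measured in optimizer update steps ("ups" in train()),
            # not epochs, since t_total counts batches per epoch * epochs // acc.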

    def train(self, train_tuple, eval_tuple):

        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        print("Batches:", len(loader))

        self.optim.zero_grad()

        best_roc = 0.
        ups = 0

        total_loss = 0.

        for epoch in range(args.epochs):

            if args.reg:
                if args.model != "X":
                    print(self.model.model.layer_weights)

            id2ans = {}
            id2prob = {}

            for i, (ids, feats, boxes, sent,
                    target) in iter_wrapper(enumerate(loader)):

                if ups == args.midsave:
                    self.save("MID")

                self.model.train()

                if args.swa:
                    self.swa_model.train()

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.long(
                ).cuda()

                # Model expects visual feats as tuple of feats & boxes
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change order, hence there should be nothing wrong with taking it as our prediction
                # In fact ROC AUC stays the exact same for logsoftmax / normal softmax, but logsoftmax is better for loss calculation
                # due to stronger penalization & decomplexifying properties (log(a/b) = log(a) - log(b))
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if i < 1:
                    print(logit[0, :].detach())

                # Note: This loss is the same as CrossEntropy (we split it into logsoftmax & neg. log likelihood loss)
                loss = self.nllloss(logit.view(-1, 2), target.view(-1))

                # Scale the loss by batch size (batches can differ in size since we do not "drop_last") and divide by acc for gradient accumulation
                # Not scaling the loss worsens performance by ~2 abs%
                loss = loss * logit.size(0) / args.acc
                loss.backward()

                total_loss += loss.detach().item()

                # Acts as argmax - extracting the higher score & the corresponding index (0 or 1)
                _, predict = logit.detach().max(1)
                # Getting labels for accuracy
                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l
                # Getting probabilities for Roc auc
                for qid, l in zip(ids, score.detach().cpu().numpy()):
                    id2prob[qid] = l

                if (i + 1) % args.acc == 0:

                    nn.utils.clip_grad_norm_(self.model.parameters(),
                                             args.clip)

                    self.optim.step()

                    if (args.swa) and (ups > self.swa_start):
                        self.swa_model.update_parameters(self.model)
                        self.swa_scheduler.step()
                    else:
                        self.scheduler.step()
                    self.optim.zero_grad()

                    ups += 1

                    # Do Validation in between
                    if ups % 250 == 0:

                        log_str = "\nEpoch(U) %d(%d): Train AC %0.2f RA %0.4f LOSS %0.4f\n" % (
                            epoch, ups, evaluator.evaluate(id2ans) * 100,
                            evaluator.roc_auc(id2prob) * 100, total_loss)

                        # Set loss back to 0 after printing it
                        total_loss = 0.

                        if self.valid_tuple is not None:  # Do Validation
                            acc, roc_auc = self.evaluate(eval_tuple)
                            if roc_auc > best_roc:
                                best_roc = roc_auc
                                best_acc = acc
                                # Only save BEST when no midsave is specified to save space
                                #if args.midsave < 0:
                                #    self.save("BEST")

                            log_str += "\nEpoch(U) %d(%d): DEV AC %0.2f RA %0.4f \n" % (
                                epoch, ups, acc * 100., roc_auc * 100)
                            log_str += "Epoch(U) %d(%d): BEST AC %0.2f RA %0.4f \n" % (
                                epoch, ups, best_acc * 100., best_roc * 100.)

                        print(log_str, end='')

                        with open(self.output + "/log.log", 'a') as f:
                            f.write(log_str)
                            f.flush()

        if (epoch + 1) == args.epochs:
            if args.contrib:
                self.optim.swap_swa_sgd()

        self.save("LAST" + args.train)

    def predict(self, eval_tuple: DataTuple, dump=None, out_csv=True):

        dset, loader, evaluator = eval_tuple
        id2ans = {}
        id2prob = {}

        for i, datum_tuple in enumerate(loader):

            ids, feats, boxes, sent = datum_tuple[:4]

            self.model.eval()

            if args.swa:
                self.swa_model.eval()

            with torch.no_grad():

                feats, boxes = feats.cuda(), boxes.cuda()
                logit = self.model(sent, (feats, boxes))

                # Note: LogSoftmax does not change order, hence there should be nothing wrong with taking it as our prediction
                logit = self.logsoftmax(logit)
                score = logit[:, 1]

                if args.swa:
                    logit = self.swa_model(sent, (feats, boxes))
                    logit = self.logsoftmax(logit)

                _, predict = logit.max(1)

                for qid, l in zip(ids, predict.cpu().numpy()):
                    id2ans[qid] = l

                # Getting probas for Roc Auc
                for qid, l in zip(ids, score.cpu().numpy()):
                    id2prob[qid] = l

        if dump is not None:
            if out_csv == True:
                evaluator.dump_csv(id2ans, id2prob, dump)
            else:
                evaluator.dump_result(id2ans, dump)

        return id2ans, id2prob

    def evaluate(self, eval_tuple: DataTuple, dump=None):
        """Evaluate all data in data_tuple."""
        id2ans, id2prob = self.predict(eval_tuple, dump=dump)

        acc = eval_tuple.evaluator.evaluate(id2ans)
        roc_auc = eval_tuple.evaluator.roc_auc(id2prob)

        return acc, roc_auc

    def save(self, name):
        if args.swa:
            torch.save(self.swa_model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))
        else:
            torch.save(self.model.state_dict(),
                       os.path.join(self.output, "%s.pth" % name))

    def load(self, path):
        print("Load model from %s" % path)

        state_dict = torch.load("%s" % path)
        new_state_dict = {}
        for key, value in state_dict.items():
            # n_averaged is a buffer added by AveragedModel that the plain model cannot load, so we skip it
            if key.startswith("n_averaged"):
                print("n_averaged:", value)
                continue
            # SWA Models will start with module
            if key.startswith("module."):
                new_state_dict[key[len("module."):]] = value
            else:
                new_state_dict[key] = value
        state_dict = new_state_dict
        self.model.load_state_dict(state_dict)
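
For reference, the key layout that load() above is undoing can be inspected directly; a hedged sketch (the layer names are placeholders, not from the original):

# Hypothetical inspection of an AveragedModel checkpoint, e.g. one written by save() above.
swa_sd = torch.load("swa_checkpoint.pth")
print(list(swa_sd)[:3])
# ['n_averaged', 'module.encoder.layer.0.weight', ...]  -- AveragedModel adds the
# "n_averaged" buffer and nests the wrapped model's keys under "module.", which is
# exactly what load() strips before calling load_state_dict on the plain model.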
Example #9
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--batch_size",
        default=8,
        type=int,
        help="batch size of both segmentation and classification training")
    parser.add_argument(
        "--seg_epoch",
        default=100,
        type=int,
        help="the number of epoch in the segmentation training")
    parser.add_argument(
        "--cls_epoch",
        default=20,
        type=int,
        help="the number of epoch in the classification training")
    parser.add_argument("--lr",
                        default=0.01,
                        type=float,
                        help="the learning rate of training")
    parser.add_argument("--swa_lr",
                        default=0.005,
                        type=float,
                        help="the stochastic learning rate of training")
    parser.add_argument(
        "--seg_weight",
        default=[0.1, 1],
        type=float,
        nargs='+',
        help="the weights of the weighted Binary Cross Entropy loss in the segmentation training")
    parser.add_argument(
        "--cls_weight",
        default=[1, 1],
        type=float,
        nargs='+',
        help="the weights of the weighted Binary Cross Entropy loss in the classification training"
    )
    parser.add_argument("--seed",
                        default=2021,
                        type=int,
                        help="the random seed")
    parser.add_argument(
        "--train_dir",
        default="/train_dir",
        type=str,
        help=
        "the training data directory; it contains both the ng and ok directories, each of which has img and mask folders."
    )
    parser.add_argument(
        "--val_dir",
        default="/val_dir",
        type=str,
        help=
        "the validation data directory; it contains both the ng and ok directories, each of which has img and mask folders."
    )

    args = parser.parse_args()

    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')
    segmentation_train = True
    classification_train = True

    train_dir = Path(args.train_dir)
    val_dir = Path(args.val_dir)

    train_ok_dir = str(train_dir / "ok")
    train_mask_dir = str(train_dir / "mask")
    train_ng_dir = str(train_dir / "ng")

    val_ok_dir = str(val_dir / "ok")
    val_mask_dir = str(val_dir / "mask")
    val_ng_dir = str(val_dir / "ng")

    seg_train_dataset = SegmentationDataset(img_dir=train_ng_dir,
                                            mask_dir=train_mask_dir,
                                            n_channels=3,
                                            classes=1,
                                            train=True)
    seg_val_dataset = SegmentationDataset(img_dir=val_ng_dir,
                                          mask_dir=val_mask_dir,
                                          n_channels=3,
                                          classes=1,
                                          train=False)

    cls_train_dataset = ClassificationDataset(ok_dir=train_ok_dir,
                                              ng_dir=train_ng_dir,
                                              n_channels=3,
                                              classes=1,
                                              train=True)
    cls_val_dataset = ClassificationDataset(ok_dir=val_ok_dir,
                                            ng_dir=val_ng_dir,
                                            n_channels=3,
                                            classes=1,
                                            train=False)

    seg_train_loader = DataLoader(seg_train_dataset,
                                  batch_size=8,
                                  shuffle=True)
    seg_val_loader = DataLoader(seg_val_dataset, batch_size=8, shuffle=True)
    cls_train_loader = DataLoader(cls_train_dataset,
                                  batch_size=8,
                                  shuffle=True)
    cls_val_loader = DataLoader(cls_val_dataset, batch_size=8, shuffle=True)

    my_model = DownconvUnet(in_channel=3, seg_classes=1, cls_classes=2)
    avg_model = AveragedModel(my_model)

    my_model.to(device)
    avg_model.to(device)

    with mlflow.start_run() as run:
        seg_args = Params(args.batch_size, args.seg_epoch, args.lr, args.seed,
                          args.seg_weight)
        cls_args = Params(args.batch_size, args.cls_epoch, args.lr, args.seed,
                          args.cls_weight)
        mode_list = ["seg", "cls"]
        for mode, mode_args in zip(mode_list, [seg_args, cls_args]):
            for key, value in vars(mode_args).items():
                mlflow.log_param(f"{mode}_{key}", value)

        # Segmentation train

        if segmentation_train:
            print("-" * 5 + "Segmentation training start" + "-" * 5)

            my_model.mode = 1
            train_metrics = Metrics()
            train_loss = 0.
            train_iou = 0.
            train_acc = 0.

            val_metrics = Metrics()
            val_loss = 0.
            val_iou = 0.
            val_acc = 0.

            my_model.train()

            optimizer = torch.optim.Adam(my_model.parameters(), lr=seg_args.lr)
            scheduler = CosineAnnealingLR(optimizer, T_max=100)
            bce = WeightedBCELoss(weight=seg_args.weight)
            swa_start = int(seg_args.num_epoch * 0.75)
            swa_scheduler = SWALR(optimizer,
                                  anneal_strategy='linear',
                                  anneal_epochs=swa_start,
                                  swa_lr=args.swa_lr)

            for epoch in range(seg_args.num_epoch):
                # reset running metrics and restore train mode at the start of each epoch
                train_loss, train_iou, train_acc = 0., 0., 0.
                val_loss, val_iou, val_acc = 0., 0., 0.
                my_model.train()

                for batch_idx, batch in enumerate(seg_train_loader):
                    batch = tuple(t.to(device) for t in batch)
                    seg_x, seg_y = batch

                    optimizer.zero_grad()

                    pred_y = my_model(seg_x)
                    loss = bce(pred_y, seg_y)
                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item()
                    train_metrics.update(pred_y, seg_y, loss.item())
                    train_iou += train_metrics.iou
                    train_acc += train_metrics.acc

                    step = epoch * len(seg_train_loader) + batch_idx
                    for metric, value in vars(train_metrics).items():
                        mlflow.log_metric(f"seg_train_{metric}",
                                          value,
                                          step=step)

                train_loss /= len(seg_train_loader)
                train_iou /= len(seg_train_loader)
                train_acc /= len(seg_train_loader)

                my_model.eval()

                for batch_idx, batch in enumerate(seg_val_loader):
                    batch = tuple(t.to(device) for t in batch)
                    seg_x, seg_y = batch
                    pred_y = my_model(seg_x)

                    loss = bce(pred_y, seg_y)

                    val_loss += loss.item()
                    val_metrics.update(pred_y, seg_y, loss.item())
                    val_iou += val_metrics.iou
                    val_acc += val_metrics.acc

                    step = epoch * len(seg_val_loader) + batch_idx
                    for metric, value in vars(val_metrics).items():
                        mlflow.log_metric(f"seg_val_{metric}",
                                          value,
                                          step=step)

                val_loss /= len(seg_val_loader)
                val_iou /= len(seg_val_loader)
                val_acc /= len(seg_val_loader)

                print(f"Epoch {epoch + 1}:")
                print("-" * 10)
                print(
                    f"train_loss {train_loss:.3f}, train_iou: {train_iou:.3f}, "
                    f"train_accuracy: {train_acc:.3f}")
                print(f"val_loss {val_loss:.3f}, val_iou: {val_iou:.3f}, "
                      f"val_accuracy: {val_acc:.3f}")

                if epoch > swa_start:
                    print("Stochastic average start")
                    avg_model.update_parameters(my_model)
                    swa_scheduler.step()
                else:
                    scheduler.step()

            print("Segmentation train completed")

            # Classification train

            if classification_train:
                print("-" * 5 + "Classification training start" + "-" * 5)

                my_model.mode = 2

                train_metrics = Metrics()
                train_loss = 0.
                train_iou = 0.
                train_acc = 0.

                val_metrics = Metrics()
                val_loss = 0.
                val_iou = 0.
                val_acc = 0.

                my_model.train()

                optimizer = torch.optim.Adam(my_model.parameters(),
                                             lr=cls_args.lr)
                scheduler = CosineAnnealingLR(optimizer, T_max=100)
                bce = WeightedBCELoss(weight=cls_args.weight)
                swa_start = int(cls_args.num_epoch * 0.75)
                swa_scheduler = SWALR(optimizer,
                                      anneal_strategy='linear',
                                      anneal_epochs=swa_start,
                                      swa_lr=args.swa_lr)

                for epoch in range(cls_args.num_epoch):
                    # reset running metrics and restore train mode at the start of each epoch
                    train_loss, train_acc = 0., 0.
                    val_loss, val_acc = 0., 0.
                    my_model.train()

                    for batch_idx, batch in enumerate(cls_train_loader):
                        batch = tuple(t.to(device) for t in batch)
                        cls_x, cls_y = batch

                        optimizer.zero_grad()

                        pred_y = my_model(cls_x)
                        loss = bce(pred_y, cls_y)
                        loss.backward()
                        optimizer.step()

                        train_loss += loss.item()
                        train_metrics.update(pred_y, cls_y, loss.item())
                        train_acc += train_metrics.acc

                        step = epoch * len(cls_train_loader) + batch_idx
                        for metric, value in vars(train_metrics).items():
                            mlflow.log_metric(f"cls_train_{metric}",
                                              value,
                                              step=step)

                    train_loss /= len(cls_train_loader)
                    train_acc /= len(cls_train_loader)

                    my_model.eval()

                    for batch_idx, batch in enumerate(cls_val_loader):
                        batch = tuple(t.to(device) for t in batch)
                        cls_x, cls_y = batch
                        pred_y = my_model(cls_x)

                        loss = bce(pred_y, cls_y)

                        val_loss += loss.item()
                        val_metrics.update(pred_y, cls_y, loss.item())
                        val_acc += val_metrics.acc

                        step = epoch * len(cls_val_loader) + batch_idx
                        for metric, value in vars(val_metrics).items():
                            mlflow.log_metric(f"cls_val_{metric}",
                                              value,
                                              step=step)

                    val_loss /= len(cls_val_loader)
                    val_acc /= len(cls_val_loader)

                    print(f"Epoch {epoch + 1}:")
                    print("-" * 10)
                    print(f"train_loss {train_loss:.3f}, "
                          f"train_accuracy: {train_acc:.3f}")
                    print(f"val_loss {val_loss:.3f}, "
                          f"val_accuracy: {val_acc:.3f}")

                print("Classification train completed")

                if epoch > swa_start:
                    print("Stochastic average start")
                    avg_model.update_parameters(my_model)
                    swa_scheduler.step()
                else:
                    scheduler.step()
    weight_path = "weights/downconv_swa_weights.pth"
    torch.save(my_model.state_dict(), weight_path)
    print(f"model weight saved to {weight_path}")
Ejemplo n.º 10
0
    def train(self, train_dataset):
        train_loader = DataLoader(train_dataset,
                                  batch_size=self.batch_size,
                                  num_workers=self.num_workers,
                                  drop_last=False,
                                  shuffle=True)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                               self.max_epochs,
                                                               eta_min=0.0005,
                                                               last_epoch=-1)
        swa_scheduler = SWALR(self.optimizer, swa_lr=0.05)
        online_swa = AveragedModel(self.online_network)
        predictor_swa = AveragedModel(self.predictor)

        swa_start = 200

        niter = 0
        model_checkpoints_folder = os.path.join(self.writer.log_dir,
                                                'checkpoints')

        self.initializes_target_network()
        self.optimizer.zero_grad()

        for epoch_counter in range(self.max_epochs):
            for (batch_view_1, batch_view_2), _ in train_loader:
                # print(batch_view_1.shape) # 256, 3, 96, 96
                batch_view_1 = batch_view_1.to(self.device)
                batch_view_2 = batch_view_2.to(self.device)


                loss = self.update(batch_view_1, batch_view_2)
                self.writer.add_scalar('loss', loss, global_step=niter)

                loss = loss / self.accumulation_steps  # Normalize our loss (if averaged)
                loss.backward()

                if (niter + 1) % self.accumulation_steps == 0:
                    # torch.nn.utils.clip_grad_norm_(list(self.online_network.parameters()) + list(self.predictor.parameters()), max_norm=0.5)
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                    self._update_target_network_parameters(
                    )  # update the key encoder

                if epoch_counter > swa_start:
                    online_swa.update_parameters(self.online_network)
                    predictor_swa.update_parameters(self.predictor)
                    swa_scheduler.step()

                niter += 1

            print("End of epoch {}".format(epoch_counter))
            scheduler.step()
            self.m = 1 - (1 - self.m_initial) * (
                math.cos(math.pi * epoch_counter / self.max_epochs) + 1) / 2
            self.save_model(os.path.join(model_checkpoints_folder,
                                         'model.pth'))

        # Update bn statistics for the swa_model at the end
        for (batch_view_1, _), _ in train_loader:
            batch_view_1 = batch_view_1.to(self.device)
            _ = online_swa.forward(batch_view_1)

        # save checkpoints
        self.save_model(os.path.join(model_checkpoints_folder, 'model.pth'),
                        save_swa=True,
                        swa=online_swa)
Ejemplo n.º 11
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    
    ## SWA
    swa_model = AveragedModel(model)
    swa_scheduler = SWALR(optimizer, swa_lr=args.learning_rate) # 1e-4

    # Train!
    print("***** Running training *****")
    print("  Num examples = %d", len(train_dataset))
    print("  Num Epochs = %d", args.num_train_epochs)
    print("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    print(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    print("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    print("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            print("  Continuing training from checkpoint, will skip to saved global_step")
            print("  Continuing training from epoch %d", epochs_trained)
            print("  Continuing training from global step %d", global_step)
            print("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            print("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    # Added here for reproducibility
    set_seed(args)

    swa_start = t_total // args.num_train_epochs * (args.num_train_epochs-1) ## SWA
    print('\n swa_start =', swa_start)
    for _ in train_iterator:
        training_pbar = tqdm(total=len(train_dataset),
                         position=0, leave=True,
                         file=sys.stdout, bar_format="{l_bar}%s{bar}%s{r_bar}" % (Fore.GREEN, Fore.RESET))
        for step, batch in enumerate(train_dataloader):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()

            tr_loss += loss.item()
            training_pbar.update(batch[0].size(0)) # hiepnh
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                ## SWA
                if global_step >= swa_start:
                    swa_model.update_parameters(model)
                    swa_scheduler.step()
                else:
                    scheduler.step()
                # scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    print("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    print("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                training_pbar.close() # hiepnh
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    ## SWA
    update_bn(train_dataloader, swa_model)

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step, swa_model
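train() returns the SWA-averaged model alongside the step count and mean loss. A minimal usage sketch for persisting it, assuming model is a Hugging Face transformers model and args.output_dir exists (the directory name below is illustrative):

global_step, avg_loss, swa_model = train(args, train_dataset, model, tokenizer)

# AveragedModel wraps the original network in .module; unwrap it before saving
swa_dir = os.path.join(args.output_dir, "checkpoint-swa")
os.makedirs(swa_dir, exist_ok=True)
swa_model.module.save_pretrained(swa_dir)
tokenizer.save_pretrained(swa_dir)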
Ejemplo n.º 12
0
    def train(self, train, configs, valid=None, test=None, checkpoint=None):
        
        if valid is not None and test is not None:
            raise ValueError('Only supply validation or test data, not both!')

        self.writer = SummaryWriter('../runs/' + configs['experiment_name'] + utils.get_time())
        self.network = self.network.to('cuda')
        
        # default values; a checkpoint may override one of the accuracies below
        epoch = 0
        prev_test_accuracy = 0
        prev_valid_accuracy = 0
        if checkpoint is not None:
            epoch, accuracy_type, prev_accuracy = (checkpoint[k] for k in ['epoch', 'accuracy_type', 'accuracy'])
            self.network.load_state_dict(checkpoint['net'])
            if accuracy_type == 'test':
                prev_test_accuracy = prev_accuracy
            else:
                prev_valid_accuracy = prev_accuracy
    
        #self.network = torch.nn.DataParallel(self.network)
        torch.backends.cudnn.benchmark = True  # good for when input size doesn't change (32x32x3)

        batch_size = configs['batch_size']
        lr = configs['initial_lr']
        self.train_loader = torch.utils.data.DataLoader(train,batch_size,shuffle=True)

        self.optimizer =  torch.optim.SGD(self.network.parameters(), lr=lr,
                      momentum=0.9, weight_decay=5e-4)
        
        criterion = nn.CrossEntropyLoss()

        max_epoch = configs['max_epoch']
        #scheduler = CosineAnnealingLR(self.optimizer, 100,eta_min=0.0,verbose=True)
        swa_start = 225
        swa_scheduler = SWALR(self.optimizer, swa_lr=0.005)
        
        for epoch in range(max_epoch):
            self.network.train()
            with tqdm.tqdm(total = len(self.train_loader)) as epoch_pbar:
                train_total = 0
                train_correct = 0
                for batch_idx, (x_train, y_train) in enumerate(self.train_loader): # use train loader to enumerate over batches
                    # predict
                    self.optimizer.zero_grad() # set gradients to zero 
                    y_preds = self.network(x_train.to('cuda'))
                    loss = criterion(y_preds, y_train.to('cuda'))

                    # accuracy check
                    train_correct += self.score(y_preds,y_train.to('cuda')) # get running average of train accuracy
                    train_total += len(y_train)
                    train_accuracy = train_correct/train_total

                    # back propagation
                    loss.backward()
                    self.optimizer.step()     

                    # update progress           
                    epoch_pbar.set_description('[training %d/%d] Loss %.2f, Accuracy %.2f' % (epoch,max_epoch,loss, train_accuracy))
                    epoch_pbar.update(1)
        
            self.writer.add_scalar('learning rate',self.optimizer.param_groups[0]['lr'],epoch)
            if epoch > swa_start:
                swa_scheduler.step()
                self.network_swa.update_parameters(self.network)
            else:
                u_lr = self.schedule_lr(epoch,configs['learn_rate_schedule'])
                print("updated lr to %f" % u_lr)

            print('Epoch Done -> Train Total %d, Correct %d' % (train_total,train_correct))
            #scheduler.step(loss) # update learning rate
           
            #self.schedule_lr(epoch,configs['learn_rate_schedule'])
            self.writer.add_scalar('Loss/train',loss,epoch)
            self.writer.add_scalar('Accuracy/train',train_accuracy,epoch)
            

            # check validation accuracy
            if valid is not None:
                valid_accuracy,valid_correct,valid_total = self.evaluate_valid(valid)
                print("Valid Accuracy  %d/%d ---> %.2f | Best ---> %.2f"  % (valid_correct,valid_total,valid_accuracy,prev_valid_accuracy))
                # checkpoint model if accuracy has improved between epochs
                if prev_valid_accuracy  < valid_accuracy:
                    print('[checkpointing model]')
                    utils.checkpoint_model(self.network.state_dict(),self.model_configs['save_dir'],epoch,valid_accuracy,acc_type='valid')
                    prev_valid_accuracy = valid_accuracy

                self.writer.add_scalar('Accuracy/valid',valid_accuracy,epoch)

   
            # check test accuracy
            if test is not None:
                test_accuracy,test_correct,test_total = self.evaluate(test,False)
                #self.writer.add_figure('predictions vs. actual', fig,epoch)
                print("Test Accuracy %d/%d---> %.2f | Best ---> %.2f"  % (test_correct,test_total,test_accuracy,prev_test_accuracy) )
                
                # checkpoint model if accuracy has improved between epochs
                if prev_test_accuracy  < test_accuracy:
                    print('[checkpointing model]')
                    utils.checkpoint_model(self.network.state_dict(),self.model_configs['save_dir'],epoch,test_accuracy)
                    prev_test_accuracy = test_accuracy

                self.writer.add_scalar('Accuracy/test',test_accuracy,epoch)
                if epoch > swa_start:
                    swa_test_accuracy, swa_test_correct, swa_test_total = self.evaluate_swa(test,None)
                    print("SWA Test Accuracy %d/%d---> %.2f"  % (swa_test_correct,swa_test_total,swa_test_accuracy) )
                    self.writer.add_scalar('Accuracy/swa_test',swa_test_accuracy,epoch)
    
            self.writer.flush()
        torch.optim.swa_utils.update_bn(self.train_loader, self.network_swa) # update batch norm parameters in swa

        if test is not None:
            swa_test_accuracy, swa_test_correct, swa_test_total = self.evaluate_swa(test, None)
            print("[FINAL] SWA Test Accuracy %d/%d ---> %.2f" % (swa_test_correct, swa_test_total, swa_test_accuracy))
        torch.save(self.network_swa.state_dict(), self.model_configs['save_dir'] + 'ckpt_swa.pth')
        self.writer.close()
        return
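The method above updates and evaluates self.network_swa but never shows where it is created. A hypothetical sketch of the missing initialization, assuming the standard torch.optim.swa_utils recipe and that AveragedModel is already imported:

    # hypothetical helper, presumably called from __init__ once self.network exists
    def init_swa(self):
        self.network_swa = AveragedModel(self.network).to('cuda')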
Ejemplo n.º 13
0
class Learner:
    def __init__(self, cfg_dir: str):
        self.cfg = get_conf(cfg_dir)
        self.logger = self.init_logger(self.cfg.logger)
        self.dataset = CustomDataset(**self.cfg.dataset)
        self.data = DataLoader(self.dataset, **self.cfg.dataloader)
        self.val_dataset = CustomDatasetVal(**self.cfg.val_dataset)
        self.val_data = DataLoader(self.val_dataset, **self.cfg.dataloader)
        self.logger.log_parameters({
            "tr_len": len(self.dataset),
            "val_len": len(self.val_dataset)
        })
        self.model = CustomModel(**self.cfg.model)
        self.model.apply(init_weights_normal)
        self.device = self.cfg.train_params.device
        self.model = self.model.to(device=self.device)
        if self.cfg.train_params.optimizer.lower() == "adam":
            self.optimizer = optim.Adam(self.model.parameters(),
                                        **self.cfg.adam)
        elif self.cfg.train_params.optimizer.lower() == "rmsprop":
            self.optimizer = optim.RMSprop(self.model.parameters(),
                                           **self.cfg.rmsprop)
        else:
            raise ValueError(
                f"Unknown optimizer {self.cfg.train_params.optimizer}")

        self.lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=100)
        self.criterion = nn.CrossEntropyLoss()

        if self.cfg.logger.resume:
            # load checkpoint
            print("Loading checkpoint")
            save_dir = self.cfg.directory.load
            checkpoint = load_checkpoint(save_dir, self.device)
            self.model.load_state_dict(checkpoint["model"])
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            self.epoch = checkpoint["epoch"]
            self.e_loss = checkpoint["e_loss"]
            self.best = checkpoint["best"]
            print(
                f"{datetime.now():%Y-%m-%d %H:%M:%S} "
                f"Loading checkpoint was successful, start from epoch {self.epoch}"
                f" and loss {self.best}")
        else:
            self.epoch = 1
            self.best = np.inf
            self.e_loss = []

        # initialize the early_stopping object
        self.early_stopping = EarlyStopping(
            patience=self.cfg.train_params.patience,
            verbose=True,
            path=self.cfg.directory.load,
            delta=self.cfg.train_params.early_stopping_delta,
        )

        # stochastic weight averaging
        self.swa_model = AveragedModel(self.model)
        self.swa_scheduler = SWALR(self.optimizer, **self.cfg.SWA)

    def train(self):
        while self.epoch <= self.cfg.train_params.epochs:
            running_loss = []
            self.model.train()

            bar = tqdm(
                enumerate(self.data),
                desc=f"Epoch {self.epoch}/{self.cfg.train_params.epochs}")
            for idx, (x, y) in bar:
                self.optimizer.zero_grad()
                # move data to device
                x = x.to(device=self.device)
                y = y.to(device=self.device)

                # forward, backward
                out = self.model(x)
                loss = self.criterion(out, y)
                loss.backward()
                # check grad norm for debugging
                grad_norm = check_grad_norm(self.model)
                # update
                self.optimizer.step()

                running_loss.append(loss.item())

                bar.set_postfix(loss=loss.item(), Grad_Norm=grad_norm)

                self.logger.log_metrics({
                    "epoch": self.epoch,
                    "batch": idx,
                    "loss": loss.item(),
                    "GradNorm": grad_norm,
                })

            bar.close()
            if self.epoch > self.cfg.train_params.swa_start:
                self.swa_model.update_parameters(self.model)
                self.swa_scheduler.step()
            else:
                self.lr_scheduler.step()

            # validate on val set
            val_loss, t = self.validate()
            t /= len(self.val_dataset)

            # average loss for an epoch
            self.e_loss.append(np.mean(running_loss))  # epoch loss
            print(
                f"{datetime.now():%Y-%m-%d %H:%M:%S} Epoch {self.epoch} summary: train Loss: {self.e_loss[-1]:.2f} \t| Val loss: {val_loss:.2f}"
                f"\t| time: {t:.3f} seconds")

            self.logger.log_metrics({
                "epoch": self.epoch,
                "epoch_loss": self.e_loss[-1],
                "val_loss": val_loss,
                "time": t,
            })

            # early_stopping needs the validation loss to check if it has decreased,
            # and if it has, it will make a checkpoint of the current model
            self.early_stopping(val_loss, self.model)

            if self.early_stopping.early_stop:
                print("Early stopping")
                self.save()
                break

            if self.epoch % self.cfg.train_params.save_every == 0:
                self.save()

            gc.collect()
            self.epoch += 1

        # Update bn statistics for the swa_model at the end
        if self.epoch >= self.cfg.train_params.swa_start:
            torch.optim.swa_utils.update_bn(self.data, self.swa_model)
            self.save(name=self.cfg.directory.model_name + "-final" +
                      str(self.epoch) + "-swa")

        macs, params = op_counter(self.model, sample=x)
        print(macs, params)
        self.logger.log_metrics({"GFLOPS": macs[:-1], "#Params": params[:-1]})
        print("Training Finished!")

    @timeit
    @torch.no_grad()
    def validate(self):

        self.model.eval()

        use_swa = self.epoch > self.cfg.train_params.swa_start
        if use_swa:
            # refresh bn statistics of the swa_model once, before evaluating it
            torch.optim.swa_utils.update_bn(self.data, self.swa_model)
            self.swa_model.eval()

        running_loss = []

        for idx, (x, y) in tqdm(enumerate(self.val_data), desc="Validation"):
            # move data to device
            x = x.to(device=self.device)
            y = y.to(device=self.device)

            # forward pass through the averaged model in the SWA phase, otherwise the raw model
            out = self.swa_model(x) if use_swa else self.model(x)

            loss = self.criterion(out, y)
            running_loss.append(loss.item())

        # average loss
        loss = np.mean(running_loss)

        return loss

    def init_logger(self, cfg):
        logger = None
        # Check to see if there is a key in environment:
        EXPERIMENT_KEY = cfg.experiment_key

        # First, let's see if we continue or start fresh:
        CONTINUE_RUN = cfg.resume
        if (EXPERIMENT_KEY is not None):
            # There is one, but the experiment might not exist yet:
            api = comet_ml.API()  # Assumes API key is set in config/env
            try:
                api_experiment = api.get_experiment_by_id(EXPERIMENT_KEY)
            except Exception:
                api_experiment = None
            if api_experiment is not None:
                CONTINUE_RUN = True
                # We can get the last details logged here, if logged:
                # step = int(api_experiment.get_parameters_summary("batch")["valueCurrent"])
                # epoch = int(api_experiment.get_parameters_summary("epochs")["valueCurrent"])

        if CONTINUE_RUN:
            # 1. Recreate the state of ML system before creating experiment
            # otherwise it could try to log params, graph, etc. again
            # ...
            # 2. Setup the existing experiment to carry on:
            logger = comet_ml.ExistingExperiment(
                previous_experiment=EXPERIMENT_KEY,
                log_env_details=True,  # to continue env logging
                log_env_gpu=True,  # to continue GPU logging
                log_env_cpu=True,  # to continue CPU logging
                auto_histogram_weight_logging=True,
                auto_histogram_gradient_logging=True,
                auto_histogram_activation_logging=True)
            # Retrieved from above APIExperiment
            # self.logger.set_epoch(epoch)

        else:
            # 1. Create the experiment first
            #    This will use the COMET_EXPERIMENT_KEY if defined in env.
            #    Otherwise, you could manually set it here. If you don't
            #    set COMET_EXPERIMENT_KEY, the experiment will get a
            #    random key!
            logger = comet_ml.Experiment(
                disabled=cfg.disabled,
                project_name=cfg.project,
                auto_histogram_weight_logging=True,
                auto_histogram_gradient_logging=True,
                auto_histogram_activation_logging=True)
            logger.add_tags(cfg.tags.split())
            logger.log_parameters(self.cfg)

        return logger

    def save(self, name=None):
        checkpoint = {
            "epoch": self.epoch,
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best": self.best,
            "e_loss": self.e_loss
        }

        if name is None and self.epoch >= self.cfg.train_params.swa_start:
            save_name = self.cfg.directory.model_name + str(
                self.epoch) + "-swa"
            checkpoint['model-swa'] = self.swa_model.state_dict()

        elif name is None:
            save_name = self.cfg.directory.model_name + str(self.epoch)

        else:
            save_name = name

        if self.e_loss[-1] < self.best:
            self.best = self.e_loss[-1]
            checkpoint["best"] = self.best
            save_checkpoint(checkpoint, True, self.cfg.directory.save,
                            save_name)
        else:
            save_checkpoint(checkpoint, False, self.cfg.directory.save,
                            save_name)
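A minimal usage sketch for the Learner above, assuming get_conf accepts a YAML config path (the path below is a placeholder):

if __name__ == "__main__":
    learner = Learner("conf/config.yaml")
    learner.train()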
Ejemplo n.º 14
0
def main():
    parser = argparse.ArgumentParser(description="Trains the network.")
    parser.add_argument("train", help="Training data (.bin or .binpack)")
    parser.add_argument("val", help="Validation data (.bin or .binpack)")

    parser.add_argument("--tune",
                        action="store_true",
                        help="automated LR search")
    parser.add_argument(
        "--save",
        action="store_true",
        help="save after every training epoch (default = False)")
    parser.add_argument("--experiment",
                        default="1",
                        type=str,
                        help="specify the experiment id")
    parser.add_argument("--py-data",
                        action="store_true",
                        help="Use python data loader (default=False)")
    parser.add_argument(
        "--lambda",
        default=1.0,
        type=float,
        dest='lambda_',
        help=
        "lambda=1.0 = train on evaluations, lambda=0.0 = train on game results, interpolates between (default=1.0)."
    )
    parser.add_argument(
        "--num-workers",
        default=1,
        type=int,
        dest='num_workers',
        help=
        "Number of worker threads to use for data loading. Currently only works well for binpack."
    )
    parser.add_argument(
        "--batch-size",
        default=-1,
        type=int,
        dest='batch_size',
        help=
        "Number of positions per batch / per iteration. Default on GPU = 8192 on CPU = 128."
    )
    parser.add_argument(
        "--threads",
        default=-1,
        type=int,
        dest='threads',
        help="Number of torch threads to use. Default automatic (cores) .")
    parser.add_argument("--seed",
                        default=42,
                        type=int,
                        dest='seed',
                        help="torch seed to use.")
    parser.add_argument(
        "--smart-fen-skipping",
        action='store_true',
        dest='smart_fen_skipping',
        help=
        "If enabled positions that are bad training targets will be skipped during loading. Default: False"
    )
    parser.add_argument(
        "--random-fen-skipping",
        default=0,
        type=int,
        dest='random_fen_skipping',
        help=
        "skip fens randomly on average random_fen_skipping before using one.")
    parser.add_argument(
        "--resume-from-model",
        dest='resume_from_model',
        help="Initializes training using the weights from the given .pt model")

    features.add_argparse_args(parser)
    args = parser.parse_args()

    print("Training with {} validating with {}".format(args.train, args.val))

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    batch_size = args.batch_size
    if batch_size <= 0:
        batch_size = 128 if args.gpus == 0 else 8192
    print('Using batch size {}'.format(batch_size))

    print('Smart fen skipping: {}'.format(args.smart_fen_skipping))
    print('Random fen skipping: {}'.format(args.random_fen_skipping))

    if args.threads > 0:
        print('limiting torch to {} threads.'.format(args.threads))
        t_set_num_threads(args.threads)

    feature_set = features.get_feature_set_from_name(args.features)

    if args.py_data:
        print('Using python data loader')
        train_data, val_data = data_loader_py(args.train, args.val, batch_size,
                                              feature_set, 'cuda:0')

    else:
        print('Using c++ data loader')
        train_data, val_data = data_loader_cc(
            args.train, args.val, feature_set, args.num_workers, batch_size,
            args.smart_fen_skipping, args.random_fen_skipping, 'cuda:0')

    print("Feature set: {}".format(feature_set.name))
    print("Num real features: {}".format(feature_set.num_real_features))
    print("Num virtual features: {}".format(feature_set.num_virtual_features))
    print("Num features: {}".format(feature_set.num_features))

    START_EPOCH = 0
    NUM_EPOCHS = 150
    SWA_START = int(0.75 * NUM_EPOCHS)

    LEARNING_RATE = 5e-4
    DECAY = 0
    EPS = 1e-7

    best_loss = 1000
    is_best = False

    early_stopping_delay = 30
    early_stopping_count = 0
    early_stopping_flag = False

    summary_location = 'logs/nnue_experiment_' + args.experiment
    save_location = '/home/esigelec/PycharmProjects/nnue-pytorch/save_models/' + args.experiment

    writer = SummaryWriter(summary_location)

    nnue = M.NNUE(feature_set=feature_set, lambda_=args.lambda_, s=1)

    train_params = [{
        'params': nnue.get_1xlr(),
        'lr': LEARNING_RATE
    }, {
        'params': nnue.get_10xlr(),
        'lr': LEARNING_RATE * 10.0
    }]

    optimizer = ranger.Ranger(train_params,
                              lr=LEARNING_RATE,
                              eps=EPS,
                              betas=(0.9, 0.999),
                              weight_decay=DECAY)

    if args.resume_from_model is not None:
        nnue, optimizer, START_EPOCH = load_ckp(args.resume_from_model, nnue,
                                                optimizer)
        nnue.set_feature_set(feature_set)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.cuda()

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.1,
                                                           patience=7,
                                                           cooldown=1,
                                                           min_lr=1e-7,
                                                           verbose=True)
    swa_scheduler = SWALR(optimizer, anneal_epochs=5, swa_lr=[5e-5, 1e-4])

    nnue = nnue.cuda()
    swa_nnue = AveragedModel(nnue)

    for epoch in range(START_EPOCH, NUM_EPOCHS):

        nnue.train()

        train_interval = 100
        loss_f_sum_interval = 0.0
        loss_f_sum_epoch = 0.0
        loss_v_sum_epoch = 0.0

        if early_stopping_flag:
            print("early end of training at epoch" + str(epoch))
            break

        for batch_idx, batch in enumerate(train_data):

            batch = [_data.cuda() for _data in batch]
            us, them, white, black, outcome, score = batch

            optimizer.zero_grad()
            output = nnue(us, them, white, black)

            loss = nnue_loss(output, outcome, score, args.lambda_)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(nnue.parameters(), 0.5)
            optimizer.step()

            loss_f_sum_interval += loss.float()
            loss_f_sum_epoch += loss.float()

            if batch_idx % train_interval == train_interval - 1:

                writer.add_scalar('train_loss',
                                  loss_f_sum_interval / train_interval,
                                  epoch * len(train_data) + batch_idx)

                loss_f_sum_interval = 0.0

        print("Epoch #{}\t Train_Loss: {:.8f}\t".format(
            epoch, loss_f_sum_epoch / len(train_data)))

        if epoch % 1 == 0 or (epoch + 1) == NUM_EPOCHS:

            with torch.no_grad():
                nnue.eval()
                for batch_idx, batch in enumerate(val_data):
                    batch = [_data.cuda() for _data in batch]
                    us, them, white, black, outcome, score = batch

                    _output = nnue(us, them, white, black)
                    loss_v = nnue_loss(_output, outcome, score, args.lambda_)
                    loss_v_sum_epoch += loss_v.float()

            if epoch > SWA_START:
                print("swa_mode")
                swa_nnue.update_parameters(nnue)
                swa_scheduler.step()
                checkpoint = {
                    'epoch': epoch + 1,
                    'state_dict': swa_nnue.state_dict(),
                    'optimizer': optimizer.state_dict()
                }
                save_ckp(checkpoint, save_location, 'swa_nnue.pt')

            else:

                scheduler.step(loss_v_sum_epoch / len(val_data))

                if loss_v_sum_epoch / len(val_data) <= best_loss:
                    best_loss = loss_v_sum_epoch / len(val_data)
                    is_best = True
                    early_stopping_count = 0
                else:
                    early_stopping_count += 1
                if early_stopping_delay == early_stopping_count:
                    early_stopping_flag = True

                if is_best:
                    checkpoint = {
                        'epoch': epoch + 1,
                        'state_dict': nnue.state_dict(),
                        'optimizer': optimizer.state_dict()
                    }
                    save_ckp(checkpoint, save_location)
                    is_best = False

            writer.add_scalar('val_loss', loss_v_sum_epoch / len(val_data),
                              epoch * len(train_data) + batch_idx)

            print("Epoch #{}\tVal_Loss: {:.8f}\t".format(
                epoch, loss_v_sum_epoch / len(val_data)))

    loss_v_sum_epoch = 0.0

    with torch.no_grad():
        swa_nnue.eval()
        for batch_idx, batch in enumerate(val_data):
            batch = [_data.cuda() for _data in batch]
            us, them, white, black, outcome, score = batch

            _output = swa_nnue(us, them, white, black)
            loss_v = nnue_loss(_output, outcome, score, args.lambda_)
            loss_v_sum_epoch += loss_v.float()

    print("Val_Loss: {:.8f}\t".format(loss_v_sum_epoch / len(val_data)))

    writer.close()
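save_ckp and load_ckp are used above but not shown; a hypothetical sketch of compatible helpers, assuming checkpoints are plain torch.save dictionaries with the keys the snippet reads and writes:

def save_ckp(state, save_dir, filename='nnue.pt'):
    # persist a checkpoint dict ({'epoch', 'state_dict', 'optimizer'}) to disk
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))

def load_ckp(ckp_path, model, optimizer):
    # restore model/optimizer state and return the epoch to resume from
    checkpoint = torch.load(ckp_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']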
Ejemplo n.º 15
0
class Re_pl(pl.LightningModule):
    def __init__(self, re_dict, *args,
                 **kwargs):  #*args, **kwargs hparams, steps_per_epoch
        super().__init__()
        self.save_hyperparameters(re_dict)
        self.save_hyperparameters()
        #self.hparams = hparams
        self.swa_model = None

        self.network = Re_model(self.hparams["model"])
        self.learning_params = self.hparams["training"]

        self.swa_mode = False

        self.criterion = nn.MSELoss()

    def forward(self, x):
        if not self.swa_mode:
            return self.network(x)  #.float())
        else:
            return self.swa_model(x)  #.float())

    def configure_optimizers(self):
        if self.learning_params["optimizer"] == "belief":
            optimizer = AdaBelief(
                self.parameters(),
                lr=self.learning_params["lr"],
                eps=self.learning_params["eplison_belief"],
                weight_decouple=self.learning_params["weight_decouple"],
                weight_decay=self.learning_params["weight_decay"],
                rectify=self.learning_params["rectify"])
        elif self.learning_params["optimizer"] == "ranger_belief":
            optimizer = RangerAdaBelief(
                self.parameters(),
                lr=self.learning_params["lr"],
                eps=self.learning_params["eplison_belief"],
                weight_decouple=self.learning_params["weight_decouple"],
                weight_decay=self.learning_params["weight_decay"],
            )
        elif self.learning_params["optimizer"] == "adam":
            optimizer = torch.optim.Adam(self.parameters(),
                                         lr=self.learning_params["lr"])
        elif self.learning_params["optimizer"] == "adamW":
            optimizer = torch.optim.AdamW(self.parameters(),
                                          lr=self.learning_params["lr"])

        if self.learning_params["add_sch"]:
            lr_scheduler = {
                'scheduler':
                torch.optim.lr_scheduler.OneCycleLR(
                    optimizer,
                    max_lr=self.learning_params["lr"],
                    steps_per_epoch=self.hparams.
                    steps_per_epoch,  #int(len(train_loader))
                    epochs=self.learning_params["epochs"],
                    anneal_strategy='linear'),
                'name':
                'lr_scheduler_lr',
                'interval':
                'step',  # or 'epoch'
                'frequency':
                1,
            }
            print("sch added")
            return [optimizer], [lr_scheduler]

        return optimizer

    def training_step(self, batch, batch_idx):
        #also Manual optimization exist
        images, landmarks = batch
        landmarks = landmarks.view(landmarks.size(0), -1)

        predictions = self(images)

        loss = self.criterion(predictions, landmarks)

        self.log('train_loss', loss, on_step=True, on_epoch=True,
                 logger=True)  # prog_bar=True
        return loss

    #copied
    def get_lr_inside(self, optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    def training_epoch_end(self, outputs):
        self.log('epoch_now',
                 self.current_epoch,
                 on_step=False,
                 on_epoch=True,
                 logger=True)
        optimizer = self.optimizers(use_pl_optimizer=True)
        self.log('lr_now',
                 self.get_lr_inside(optimizer),
                 on_step=False,
                 on_epoch=True,
                 logger=True)
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/3095
        if self.learning_params["swa"] and (
                self.current_epoch >= self.learning_params["swa_start_epoch"]):
            if self.swa_model is None:
                (optimizer) = self.optimizers(use_pl_optimizer=True)
                print("creating_swa")
                self.swa_model = AveragedModel(self.network)
                self.new_scheduler = SWALR(
                    optimizer,
                    anneal_strategy="linear",
                    anneal_epochs=5,
                    swa_lr=self.learning_params["swa_lr"])
            # https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
            self.swa_model.update_parameters(self.network)
            self.new_scheduler.step()

    def change_for_swa(self, loader):
        print("will it work?")
        torch.optim.swa_utils.update_bn(loader, self.swa_model)
        self.swa_mode = True
        return

    def validation_step(self, batch, batch_idx):

        images, landmarks = batch
        landmarks = landmarks.view(landmarks.size(0), -1)

        predictions = self(images)

        loss = self.criterion(predictions, landmarks)

        self.log('val_loss', loss, on_step=False, on_epoch=True,
                 logger=True)  # prog_bar=True

        return {'val_loss': loss}

    def test_step(self, batch, batch_idx):
        images, landmarks = batch
        landmarks = landmarks.view(landmarks.size(0), -1)

        predictions = self(images)

        loss = self.criterion(predictions, landmarks)

        self.log('test_loss', loss, on_step=False, on_epoch=True,
                 logger=True)  #prog_bar=True,

        return {'test_loss': loss}
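A minimal usage sketch for the module above, assuming an older PyTorch Lightning Trainer API and pre-built dataloaders (all names below are placeholders):

model = Re_pl(re_dict)
trainer = pl.Trainer(max_epochs=re_dict["training"]["epochs"], gpus=1)
trainer.fit(model, train_loader, val_loader)

# after training, refresh BN statistics and make forward() use the averaged weights
if re_dict["training"]["swa"]:
    model.change_for_swa(train_loader)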
Ejemplo n.º 16
0
def main():
    os.makedirs(SAVEPATH, exist_ok=True)
    print('save path:', SAVEPATH)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('device:', device)

    print('weight_decay:', WEIGHTDECAY)
    print('momentum:', MOMENTUM)
    print('batch_size:', BATCHSIZE)
    print('lr:', LR)
    print('epoch:', EPOCHS)
    print('Label smoothing:', LABELSMOOTH)
    print('Stochastic Weight Averaging:', SWA)
    if SWA:
        print('Swa lr:', SWA_LR)
        print('Swa start epoch:', SWA_START)
    print('Cutout augmentation:', CUTOUT)
    if CUTOUT:
        print('Cutout size:', CUTOUTSIZE)
    print('Activation:', ACTIVATION)

    # get model
    model = get_seresnet_cifar(activation=ACTIVATION)

    # get loss function
    if LABELSMOOTH:
        criterion = LabelSmoothingLoss(classes=10, smoothing=0.1)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=LR,
                                momentum=MOMENTUM,
                                weight_decay=WEIGHTDECAY,
                                nesterov=True)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                           T_max=EPOCHS,
                                                           eta_min=0)

    model = model.to(device)
    criterion = criterion.to(device)

    # Check number of parameters your model
    pytorch_total_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters: {pytorch_total_params}")
    if int(pytorch_total_params) > 2000000:
        print('Your model has more than 2 million parameters.')
        return

    if SWA:
        # apply swa
        swa_model = AveragedModel(model)
        swa_scheduler = SWALR(optimizer, swa_lr=SWA_LR)
        swa_total_params = sum(p.numel() for p in swa_model.parameters())
        print(f"Swa parameters: {swa_total_params}")

    # cinic mean, std
    normalize = transforms.Normalize(mean=[0.47889522, 0.47227842, 0.43047404],
                                     std=[0.24205776, 0.23828046, 0.25874835])

    if CUTOUT:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize,
            Cutout(size=CUTOUTSIZE)
        ])
    else:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(), normalize
        ])

    train_dataset = torchvision.datasets.ImageFolder('/content/train',
                                                     transform=train_transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=BATCHSIZE,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)

    # colab reload
    start_epoch = 0
    if os.path.isfile(os.path.join(SAVEPATH, 'latest_checkpoint.pth')):
        checkpoint = torch.load(os.path.join(SAVEPATH,
                                             'latest_checkpoint.pth'))
        start_epoch = checkpoint['epoch']
        scheduler.load_state_dict(checkpoint['scheduler'])
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if SWA:
            swa_scheduler.load_state_dict(checkpoint['swa_scheduler'])
            swa_model.load_state_dict(checkpoint['swa_model'])
        print(start_epoch, 'load parameter')

    for epoch in range(start_epoch, EPOCHS):
        print("\n----- epoch: {}, lr: {} -----".format(
            epoch, optimizer.param_groups[0]["lr"]))

        # train for one epoch
        start_time = time.time()
        train(train_loader, epoch, model, optimizer, criterion, device)
        elapsed_time = time.time() - start_time
        print('==> {:.2f} seconds to train this epoch\n'.format(elapsed_time))

        # learning rate scheduling
        if SWA and epoch > SWA_START:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()

        if SWA:
            checkpoint = {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'swa_model': swa_model.state_dict(),
                'swa_scheduler': swa_scheduler.state_dict()
            }
        else:
            checkpoint = {
                'epoch': epoch + 1,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }
        torch.save(checkpoint, os.path.join(SAVEPATH, 'latest_checkpoint.pth'))
        if epoch % 10 == 0:
            torch.save(checkpoint,
                       os.path.join(SAVEPATH, '%d_checkpoint.pth' % epoch))
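The loop above checkpoints the running SWA state every epoch but never refreshes its BatchNorm statistics. A possible final step after the epoch loop, under the same assumptions as the rest of the script (the file name below is illustrative):

    if SWA:
        # recompute BatchNorm running statistics for the averaged model, then save it separately
        torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)
        torch.save(swa_model.state_dict(), os.path.join(SAVEPATH, 'swa_final.pth'))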
Ejemplo n.º 17
0
class ModelTrainer:
    def __init__(self, config: DNNConfig):
        self.config = config
        self.epochs = config.epoch_num
        self.device = config.device

        self.model = tmp_model
        #self.criterion = CustomLoss()

        self.criterion = nn.MSELoss()

        optimizer_kwargs = {
            'lr': config.lr,
            'weight_decay': config.weight_decay
        }
        self.sam = config.issam
        self.optimizer = make_optimizer(self.model,
                                        optimizer_kwargs,
                                        optimizer_name=config.optimizer_name,
                                        sam=config.issam)
        self.scheduler_name = config.scheduler_name
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer=self.optimizer, T_max=config.T_max)

        self.isswa = getattr(config, 'isswa', False)
        self.swa_start = getattr(config, 'swa_start', 0)

        if self.isswa:
            self.swa_model = AveragedModel(self.model)
            self.swa_scheduler = SWALR(self.optimizer, swa_lr=0.025)

        #self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer,
        #                                                      mode=config.mode, factor=config.factor)

        self.loss_log = {
            'train_loss': [],
            'train_score': [],
            'valid_loss': [],
            'valid_score': []
        }

    def load_params(self, state_dict):
        self.model.load_state_dict(state_dict)

    def save_loss_log(self, save_name: str):
        with open(save_name, 'wb') as f:
            pickle.dump(self.loss_log, f)

    def loss_fn(self, y, preds):
        criterion = nn.MSELoss()
        loss = criterion(y, preds)
        return loss

    def valid_fn(self, y, preds):
        score = metrics.f1_score(y, preds, average='macro')
        return score

    def reshape_targets(self, y):
        return y.squeeze(1)

    def train(self, trn_dataloader, epoch):
        self.model.to(self.device)
        self.model.train()
        preds, targets, losses = [], [], []
        with tqdm(total=len(trn_dataloader), unit="batch") as pbar:
            pbar.set_description(f"[train] Epoch {epoch+1}/{self.epochs}")
            for data in trn_dataloader:
                x = data['input']
                y = self.reshape_targets(data['label'])
                output = self.model(x.to(self.device))
                if not self.sam:
                    loss = self.loss_fn(output, y.to(self.device))
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                else:
                    loss = self.loss_fn(output, y.to(self.device))
                    loss.backward()
                    self.optimizer.first_step(zero_grad=True)
                    # second forward-backward pass, make sure to do a full forward pass
                    self.loss_fn(self.model(x.to(self.device)),
                                 y.to(self.device)).backward()
                    self.optimizer.second_step(zero_grad=True)

                if self.scheduler_name != 'ReduceLROnPlateau':
                    if self.isswa and self.swa_start <= epoch:  # during the SWA phase, skip the per-step scheduler
                        pass
                    else:
                        self.scheduler.step()

                losses.append(loss.item())
                preds += output.detach().cpu().tolist()  # torch.argmax(output, dim=1).detach().cpu().tolist()
                targets += y.detach().cpu().tolist()
                batch_score = self.valid_fn(np.array(targets), np.array(preds))
                pbar.set_postfix(loss=np.mean(losses), score=batch_score)
                pbar.update(1)

        # End of epoch: in the SWA phase, update the averaged weights and step the SWA
        # scheduler; otherwise let ReduceLROnPlateau react to the epoch score.
        if self.isswa and self.swa_start <= epoch:
            self.swa_model.to(self.device)
            self.swa_model.update_parameters(self.model)
            self.swa_scheduler.step()
            self.swa_model.to('cpu')
        elif self.scheduler_name == 'ReduceLROnPlateau':
            self.scheduler.step(batch_score)

        self.loss_log['train_loss'].append(np.mean(losses))
        self.loss_log['train_score'].append(batch_score)
        self.model.to('cpu')
        self.model.eval()

    def eval(self, val_dataloader, epoch):
        self.model.to(self.device)
        self.model.eval()
        preds, targets, losses = [], [], []
        with tqdm(total=len(val_dataloader), unit="batch") as pbar:
            pbar.set_description(f"[eval]  Epoch {epoch+1}/{self.epochs}")
            with torch.no_grad():
                for data in val_dataloader:
                    x = data['input']
                    y = self.reshape_targets(data['label'])
                    output = self.model(x.to(self.device)).cpu()
                    target = y.cpu()
                    losses.append(self.loss_fn(target, output).item())
                    preds += output.tolist()  # torch.argmax(output, dim=1).detach().cpu().tolist()
                    targets += target.tolist()
                    batch_score = self.valid_fn(np.array(targets),
                                                np.array(preds))
                    pbar.set_postfix(loss=np.mean(losses), score=batch_score)
                    pbar.update(1)
        self.loss_log['valid_loss'].append(np.mean(losses))
        self.loss_log['valid_score'].append(batch_score)
        self.model.to('cpu')

    def inference_swa(self, train_loader, data_loader):
        self.swa_model.to(self.device)
        # update_bn expects each batch to be (or begin with) the input tensor, so unpack
        # the dict batches here; it moves them to the given device and recomputes the
        # BatchNorm running statistics with the averaged weights.
        torch.optim.swa_utils.update_bn(
            (data['input'] for data in train_loader), self.swa_model, device=self.device)
        self.swa_model.eval()
        preds = []
        with tqdm(total=len(data_loader), unit="batch") as pbar:
            pbar.set_description(f"[inference]")
            with torch.no_grad():
                for data in data_loader:
                    x = data['input']
                    output = self.swa_model(x.to(self.device))
                    preds.append(output.cpu().numpy())
                    #hidden_state.append(hidden.cpu().numpy())
                    pbar.update(1)
        return np.vstack(preds)

    def inference(self, data_loader):
        self.model.to(self.device)
        self.model.eval()
        preds, hidden_state = [], []
        with tqdm(total=len(data_loader), unit="batch") as pbar:
            pbar.set_description(f"[inference]")
            with torch.no_grad():
                for data in data_loader:
                    x = data['input']
                    output = self.model(x.to(self.device))
                    preds.append(output.cpu().numpy())
                    #hidden_state.append(hidden.cpu().numpy())
                    pbar.update(1)
        return np.vstack(preds)  #, np.vstack(hidden_state)
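
# --- Usage sketch (assumption, not part of the original class) ---
# A rough driver for ModelTrainer; `config` (a DNNConfig) and the dataloaders are
# hypothetical objects built elsewhere, so this is illustrative only.
# trainer = ModelTrainer(config)
# for epoch in range(config.epoch_num):
#     trainer.train(train_loader, epoch)
#     trainer.eval(valid_loader, epoch)
# test_preds = (trainer.inference_swa(train_loader, test_loader)
#               if config.isswa else trainer.inference(test_loader))
# trainer.save_loss_log('loss_log.pkl')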
Ejemplo n.º 18
0
class Learner:
    def __init__(self, cfg_dir: str):
        # load config file and initialize the logger and the device
        self.cfg = get_conf(cfg_dir)
        self.logger = self.init_logger(self.cfg.logger)
        self.device = self.init_device()
        # creating dataset interface and dataloader for trained data
        self.data, self.val_data = self.init_dataloader()
        # create model and initialize its weights and move them to the device
        self.model = self.init_model()
        # initialize the optimizer
        self.optimizer, self.lr_scheduler = self.init_optimizer()
        # define loss function
        self.criterion = torch.nn.CrossEntropyLoss()
        # if resuming, load the checkpoint
        self.if_resume()

        # initialize the early_stopping object
        self.early_stopping = EarlyStopping(
            patience=self.cfg.train_params.patience,
            verbose=True,
            delta=self.cfg.train_params.early_stopping_delta,
        )

        # stochastic weight averaging
        if self.cfg.train_params.epochs > self.cfg.train_params.swa_start:
            self.swa_model = AveragedModel(self.model)
            self.swa_scheduler = SWALR(self.optimizer, **self.cfg.SWA)

    def train(self):
        """Trains the model"""
        # a variable to print the start of SWA
        print_swa_start = True

        while self.epoch <= self.cfg.train_params.epochs:
            running_loss = []

            bar = tqdm(
                self.data,
                desc=f"Epoch {self.epoch:03}/{self.cfg.train_params.epochs:03}, training: ",
            )
            for data in bar:
                self.iteration += 1
                (loss, grad_norm), t_train = self.forward_batch(data)
                t_train /= self.data.batch_size
                running_loss.append(loss)

                bar.set_postfix(loss=loss, Grad_Norm=grad_norm, Time=t_train)

                self.logger.log_metrics(
                    {
                        "batch_loss": loss,
                        "grad_norm": grad_norm,
                    },
                    epoch=self.epoch,
                    step=self.iteration,
                )

            bar.close()
            # update SWA model parameters
            if self.epoch > self.cfg.train_params.swa_start:
                if print_swa_start:
                    print(f"Epoch {self.epoch:03}, step {self.iteration:05}, starting SWA!")
                    # print only once
                    print_swa_start = False

                self.swa_model.update_parameters(self.model)
                self.swa_scheduler.step()
            else:
                self.lr_scheduler.step()

            # validate on val set
            val_loss, t = self.validate()
            t /= len(self.val_data.dataset)

            # average loss for an epoch
            self.e_loss.append(np.mean(running_loss))  # epoch loss
            print(
                f"{datetime.now():%Y-%m-%d %H:%M:%S} Epoch {self.epoch:03}, " +
                f"Iteration {self.iteration:05} summary: train Loss: " +
                f"{self.e_loss[-1]:.2f} \t| Val loss: {val_loss:.2f}" +
                f"\t| time: {t:.3f} seconds\n"
            )

            self.logger.log_metrics(
                {
                    "train_loss": self.e_loss[-1],
                    "val_loss": val_loss,
                    "time": t,
                },
                epoch=self.epoch,
                step=self.iteration,
            )

            # early_stopping needs the validation loss to check if it has decreased,
            # and if it has, it will make a checkpoint of the current model
            self.early_stopping(val_loss, self.model)

            if self.early_stopping.early_stop and self.cfg.train_params.early_stopping:
                print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - Epoch {self.epoch:03}, Early stopping")
                self.save()
                break

            if self.epoch % self.cfg.train_params.save_every == 0 or (
                self.e_loss[-1] < self.best
                and self.epoch % self.cfg.train_params.start_saving_best == 0
            ):
                self.save()

            gc.collect()
            self.epoch += 1

        # Update the BatchNorm statistics of the SWA model at the end of training
        if self.epoch >= self.cfg.train_params.swa_start:
            # torch.optim.swa_utils.update_bn(self.data, self.swa_model) assumes each
            # batch is (or starts with) the input tensor; this dataloader yields
            # (uid, x, y), so instead run every training batch through the SWA model
            # in training mode (BatchNorm only accumulates statistics in train mode).
            self.swa_model.train()
            with torch.no_grad():
                for uid, x, y in self.data:
                    x = x.to(device=self.device)
                    self.swa_model(x)

            self.save(
                name=self.cfg.directory.model_name + "-final" + str(self.epoch) + "-swa"
            )
        _, x, _ = next(iter(self.data))
        macs, params = op_counter(self.model, sample=x)
        print("macs = ", macs, " | params = ", params)
        self.logger.log_metrics({"GFLOPS": macs[:-1], "#Params": params[:-1]})
        print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - Training is DONE!")

    @timeit
    def forward_batch(self, data):
        """Forward pass of a batch"""
        self.model.train()
        # move data to device
        uuid, x, y = data
        x = x.to(device=self.device)
        y = y.to(device=self.device)

        # forward, backward
        out = self.model(x)
        loss = self.criterion(out, y)
        self.optimizer.zero_grad()
        loss.backward()
        # gradient clipping
        if self.cfg.train_params.grad_clipping > 0:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), self.cfg.train_params.grad_clipping
            )
        # check grad norm for debugging
        grad_norm = check_grad_norm(self.model)
        # update
        self.optimizer.step()

        return loss.item(), grad_norm

    @timeit
    @torch.no_grad()
    def validate(self):

        self.model.eval()

        running_loss = []
        bar = tqdm(self.val_data, desc=f"Epoch {self.epoch:03}/{self.cfg.train_params.epochs:03}, validating")
        for uid, x, y in bar:
            # move data to device
            x = x.to(device=self.device)
            y = y.to(device=self.device)

            # forward
            if self.epoch > self.cfg.train_params.swa_start:
                out = self.swa_model(x)
            else:
                out = self.model(x)

            loss = self.criterion(out, y)
            running_loss.append(loss.item())
            bar.set_postfix(loss=loss.item())

        bar.close()

        self.logger.log_image(x[0].squeeze(), f"{out[0].argmax().item()}-|{uid[0]}", step=self.iteration)

        # average loss
        loss = np.mean(running_loss)

        return loss

    def init_model(self):
        """Initializes the model"""
        print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - INITIALIZING the model!")
        model = CustomModel(self.cfg.model)

        if 'cuda' in str(self.device) and self.cfg.train_params.device.split(":")[1] == 'a':
            model = torch.nn.DataParallel(model)

        model.apply(init_weights(**self.cfg.init_model))
        model = model.to(device=self.device)
        return model

    def init_optimizer(self):
        """Initializes the optimizer and learning rate scheduler"""
        print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - INITIALIZING the optimizer!")
        if self.cfg.train_params.optimizer.lower() == "adam":
            optimizer = optim.Adam(self.model.parameters(), **self.cfg.adam)

        elif self.cfg.train_params.optimizer.lower() == "rmsprop":
            optimizer = optim.RMSprop(self.model.parameters(), **self.cfg.rmsprop)

        elif self.cfg.train_params.optimizer.lower() == "sgd":
            optimizer = optim.SGD(self.model.parameters(), **self.cfg.sgd)

        else:
            raise ValueError(f"Unknown optimizer {self.cfg.train_params.optimizer}; "
                             "valid optimizers are 'adam', 'rmsprop', and 'sgd'.")

        # initialize the learning rate scheduler
        lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.cfg.train_params.epochs
        )
        return optimizer, lr_scheduler

    def init_device(self):
        """Initializes the device"""
        print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - INITIALIZING the device!")
        is_cuda_available = torch.cuda.is_available()
        device = self.cfg.train_params.device

        if 'cpu' in device:
            print(f"Performing all the operations on CPU.")
            return torch.device(device)

        elif 'cuda' in device:
            if is_cuda_available:
                device_idx = device.split(":")[1]
                if device_idx == 'a':
                    print(f"Performing all the operations on CUDA; {torch.cuda.device_count()} devices.")
                    self.cfg.dataloader.batch_size *= torch.cuda.device_count()
                    return torch.device(device.split(":")[0])
                else:
                    print(f"Performing all the operations on CUDA device {device_idx}.")
                    return torch.device(device)
            else:
                print("CUDA device is not available, falling back to CPU!")
                return torch.device('cpu')
        else:
            raise ValueError(f"Unknown device {device}!")

    def init_dataloader(self):
        """Initializes the dataloaders"""
        print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - INITIALIZING the train and val dataloaders!")
        dataset = CustomDataset(**self.cfg.dataset)
        data = DataLoader(dataset, **self.cfg.dataloader)
        # creating dataset interface and dataloader for val data
        self.cfg.val_dataset.update(self.cfg.dataset)
        val_dataset = CustomDataset(**self.cfg.val_dataset)

        self.cfg.dataloader.update({'shuffle': False})  # for val dataloader
        val_data = DataLoader(val_dataset, **self.cfg.dataloader)

        # log dataset status
        self.logger.log_parameters(
            {"train_len": len(dataset), "val_len": len(val_dataset)}
        )
        print(f"Training consists of {len(dataset)} samples, and validation consists of {len(val_dataset)} samples.")
        self.logger.log_asset_data(json.dumps(dict(val_dataset.cache_names)), 'val-data-uuid.json')
        self.logger.log_asset_data(json.dumps(dict(dataset.cache_names)), 'train-data-uuid.json')

        return data, val_data

    def if_resume(self):
        if self.cfg.logger.resume:
            # load checkpoint
            print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - LOADING checkpoint!!!")
            save_dir = self.cfg.directory.load
            checkpoint = load_checkpoint(save_dir, self.device)
            self.model.load_state_dict(checkpoint["model"])
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            self.epoch = checkpoint["epoch"] + 1
            self.e_loss = checkpoint["e_loss"]
            self.iteration = checkpoint["iteration"] + 1
            self.best = checkpoint["best"]
            print(
                f"{datetime.now():%Y-%m-%d %H:%M:%S} " +
                f"LOADING checkpoint was successful, start from epoch {self.epoch}" +
                f" and loss {self.best}"
            )
        else:
            self.epoch = 1
            self.iteration = 0
            self.best = np.inf
            self.e_loss = []

        self.logger.set_epoch(self.epoch)

    def init_logger(self, cfg):
        print(f"{datetime.now():%Y-%m-%d %H:%M:%S} - INITIALIZING the logger!")
        logger = None
        # Check to see if there is a key in environment:
        EXPERIMENT_KEY = cfg.experiment_key

        # First, let's see if we continue or start fresh:
        CONTINUE_RUN = cfg.resume
        if EXPERIMENT_KEY and CONTINUE_RUN:
            # There is one, but the experiment might not exist yet:
            api = comet_ml.API()  # Assumes API key is set in config/env
            try:
                api_experiment = api.get_experiment_by_id(EXPERIMENT_KEY)
            except Exception:
                api_experiment = None
            if api_experiment is not None:
                CONTINUE_RUN = True
                # We can get the last details logged here, if logged:
                # step = int(api_experiment.get_parameters_summary("batch")["valueCurrent"])
                # epoch = int(api_experiment.get_parameters_summary("epochs")["valueCurrent"])

        if CONTINUE_RUN:
            # 1. Recreate the state of ML system before creating experiment
            # otherwise it could try to log params, graph, etc. again
            # ...
            # 2. Setup the existing experiment to carry on:
            logger = comet_ml.ExistingExperiment(
                previous_experiment=EXPERIMENT_KEY,
                log_env_details=True,  # to continue env logging
                log_env_gpu=True,  # to continue GPU logging
                log_env_cpu=True,  # to continue CPU logging
                auto_histogram_weight_logging=True,
                auto_histogram_gradient_logging=True,
                auto_histogram_activation_logging=True,
            )
            # Retrieved from above APIExperiment
            # self.logger.set_epoch(epoch)

        else:
            # 1. Create the experiment first
            #    This will use the COMET_EXPERIMENT_KEY if defined in env.
            #    Otherwise, you could manually set it here. If you don't
            #    set COMET_EXPERIMENT_KEY, the experiment will get a
            #    random key!
            if cfg.online:
                logger = comet_ml.Experiment(
                    disabled=cfg.disabled,
                    project_name=cfg.project,
                    auto_histogram_weight_logging=True,
                    auto_histogram_gradient_logging=True,
                    auto_histogram_activation_logging=True,
                )
                logger.add_tags(cfg.tags.split())
                logger.log_parameters(self.cfg)
            else:
                logger = comet_ml.OfflineExperiment(
                    disabled=cfg.disabled,
                    project_name=cfg.project,
                    offline_directory=cfg.offline_directory,
                    auto_histogram_weight_logging=True,
                )
                logger.set_name(cfg.experiment_name)
                logger.add_tags(cfg.tags.split())
                logger.log_parameters(self.cfg)

        return logger

    def save(self, name=None):
        model = self.model
        if isinstance(self.model, torch.nn.DataParallel):
            model = model.module

        checkpoint = {
            "time": str(datetime.now()),
            "epoch": self.epoch,
            "iteration": self.iteration,
            "model": model.state_dict(),
            "model_name": type(model).__name__,
            "optimizer": self.optimizer.state_dict(),
            "optimizer_name": type(self.optimizer).__name__,
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best": self.best,
            "e_loss": self.e_loss,
        }

        if name is None and self.epoch >= self.cfg.train_params.swa_start:
            save_name = self.cfg.directory.model_name + str(self.epoch) + "-swa"
            checkpoint["model-swa"] = self.swa_model.state_dict()

        elif name is None:
            save_name = self.cfg.directory.model_name + str(self.epoch)

        else:
            save_name = name

        if self.e_loss[-1] < self.best:
            self.best = self.e_loss[-1]
            checkpoint["best"] = self.best
            save_checkpoint(checkpoint, True, self.cfg.directory.save, save_name)
        else:
            save_checkpoint(checkpoint, False, self.cfg.directory.save, save_name)
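
# --- Usage sketch (assumption, not part of the original class) ---
# The Learner is typically driven by a small entry point; the config path below
# is hypothetical.
# if __name__ == "__main__":
#     learner = Learner("configs/train.yaml")
#     learner.train()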
Ejemplo n.º 19
0
def training(model,
             train_dataloader,
             valid_dataloader,
             test_dataloader,
             model_cfg,
             fold_idx=1):

    print("--------  ", str(fold_idx), "  --------")
    global model_config
    model_config = model_cfg

    device = get_device()
    model.to(device)

    if fold_idx == 1:
        print('CONFIG: ')
        print([(v, getattr(model_config, v)) for v in dir(model_config)
               if v[:2] != "__"])
        print('MODEL: ', model)

    epochs = model_config.epochs

    if model_config.optimizer == 'AdamW':
        optimizer = torch.optim.AdamW(model.parameters(),
                                      lr=float(model_config.lr),
                                      eps=float(model_config.eps),
                                      weight_decay=float(
                                          model_config.weight_decay))
    elif model_config.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=float(model_config.lr))

    if model_config.scheduler == 'linear':
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(model_config.warmup_steps),
            num_training_steps=len(train_dataloader) * epochs)
    else:
        scheduler = None

    criterion = nn.BCEWithLogitsLoss()  #nn.CrossEntropyLoss()

    swa_model = AveragedModel(model)
    if model_config.swa_scheduler == 'linear':
        swa_scheduler = SWALR(optimizer, swa_lr=float(model_config.lr))
    else:
        swa_scheduler = CosineAnnealingLR(optimizer, T_max=100)

    print('TRAINING...')

    training_stats = []

    best_dev_auc = float('-inf')

    with tqdm(total=epochs, leave=False) as pbar:
        for epoch_i in range(0, epochs):

            if epoch_i >= int(model_config.swa_start):
                update_bn(train_dataloader, swa_model)
                train_auc, train_acc, avg_train_loss = train(
                    model, train_dataloader, device, criterion, optimizer)
                swa_model.update_parameters(model)
                swa_scheduler.step()
                update_bn(valid_dataloader, swa_model)
                valid_auc, valid_acc, avg_dev_loss, dev_d = valid(
                    swa_model, valid_dataloader, device, criterion)
            else:
                train_auc, train_acc, avg_train_loss = train(
                    model,
                    train_dataloader,
                    device,
                    criterion,
                    optimizer,
                    scheduler=scheduler)
                valid_auc, valid_acc, avg_dev_loss, dev_d = valid(
                    model, valid_dataloader, device, criterion)
            if cfg.final_train:
                valid_auc = 0
                valid_acc = 0
                avg_dev_loss = 0

            add_stats(training_stats, avg_train_loss, avg_dev_loss, train_acc,
                      train_auc, valid_acc, valid_auc)

            if (cfg.final_train and epoch_i == epochs - 1) or (
                    not cfg.final_train and valid_auc > best_dev_auc):
                best_dev_auc = valid_auc
                if epoch_i >= int(model_config.swa_start):
                    update_bn(test_dataloader, swa_model)
                    test_d = gen_test(swa_model, test_dataloader, device)
                    save(fold_idx, swa_model, optimizer, dev_d, test_d,
                         valid_auc)
                else:
                    test_d = gen_test(model, test_dataloader, device)
                    save(fold_idx, model, optimizer, dev_d, test_d, valid_auc)

            pbar.update(1)

    print('TRAINING COMPLETED')

    # Show training results
    col_names = [
        'train_loss', 'train_acc', 'train_auc', 'dev_loss', 'dev_acc',
        'dev_auc'
    ]
    training_stats = pd.DataFrame(training_stats, columns=col_names)
    print(training_stats.head(epochs))
    plot_training_results(training_stats, fold_idx)

    # If configured, make a submission with the trained model
    if cfg.run['submission']:
        make_submission(model, test_dataloader)
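
# --- Note (assumption, not part of the original example) ---
# torch.optim.swa_utils.update_bn resets every BatchNorm layer's running statistics
# and recomputes them in one full pass over the given loader, so it is commonly run
# once over the training data after the SWA phase rather than before every epoch and
# on the validation/test loaders as above. A minimal sketch reusing this example's
# names:
# update_bn(train_dataloader, swa_model)
# valid_auc, valid_acc, avg_dev_loss, dev_d = valid(swa_model, valid_dataloader,
#                                                   device, criterion)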
Ejemplo n.º 20
0
def main():

    # data
    mean = 0.1307
    std = 0.3081

    num_epochs = 120
    batch_size = 256
    num_workers = 4
    num_inputs = 28 * 28
    num_classes = 10
    lr = 0.1

    # swa
    swa_lr = 0.01
    swa_start = 100

    # ------------------------ dataset generator----------------------------------

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(mean, ), std=(std, ))
    ])

    mnist_train, mnist_eval = load_dataset(transform)

    # visualize
    # for x, Y in mnist_train:
    #     x = np.transpose(x, (1, 2, 0))
    #     plt.imshow(x)
    #     plt.show()
    #     break

    # dataloader
    train_generator = data.DataLoader(mnist_train,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers)
    eval_generator = data.DataLoader(mnist_eval,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=num_workers)

    #---------------------------------model----------------------------------------

    model = FashionModel(num_classes=num_classes)
    model = model.to(device)

    # loss
    criterion = nn.CrossEntropyLoss()  # applies log-softmax internally

    # optimizer
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    schedule = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                   lr,
                                                   num_epochs,
                                                   pct_start=.1,
                                                   div_factor=10,
                                                   final_div_factor=10)

    # swa
    swa_model = AveragedModel(model=model)
    swa_scheduler = SWALR(optimizer,
                          anneal_strategy="cos",
                          anneal_epochs=5,
                          swa_lr=swa_lr)

    for epoch in range(num_epochs):
        print('Epoch: {}'.format(epoch))
        train_acc, train_loss = train(model,
                                      train_generator,
                                      criterion,
                                      optimizer=optimizer)
        eval_acc, eval_loss = eval(model, eval_generator, criterion)
        writer.add_scalars('acc', {
            'train': train_acc,
            'eval': eval_acc
        },
                           global_step=epoch)
        writer.add_scalars('loss', {
            'train': train_loss,
            'eval': eval_loss
        },
                           global_step=epoch)

        writer.add_scalar('lr',
                          optimizer.param_groups[0]['lr'],
                          global_step=epoch)

        if epoch > swa_start:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            schedule.step()

    # save model
    save_model(model, model_path)

    # save swa model
    torch.optim.swa_utils.update_bn(train_generator, swa_model, device=device)

    swa_acc, swa_loss = eval(swa_model, eval_generator, criterion)
    print('swa acc: {}, loss: {}'.format(swa_acc, swa_loss))
    save_model(swa_model, swa_model_path)
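
# --- Note (assumption, not part of the original example) ---
# AveragedModel keeps the wrapped network under `self.module` and registers an
# `n_averaged` buffer, so a saved swa_model state_dict has keys prefixed with
# "module." plus "n_averaged". A sketch for loading such a state_dict back into a
# plain FashionModel (assuming save_model stored swa_model.state_dict()):
# plain = FashionModel(num_classes=num_classes)
# swa_state = torch.load(swa_model_path, map_location="cpu")
# plain.load_state_dict({k[len("module."):]: v for k, v in swa_state.items()
#                        if k.startswith("module.")})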
Ejemplo n.º 21
0
                     "Epoch: {}, Training times is {}, Loss is {}, Learning rate is {}"
                     .format(
                         e, trainingTimes, loss.item(),
                         optimizer.state_dict()['param_groups'][0]["lr"]))
                 with torch.no_grad():
                     _, predicted = predict.max(1)
                     total = labelsCuda.size(0)
                     correct = predicted.eq(labelsCuda).sum().item()
                     print(
                         "The Predict is {}, label is {}, correct ratio {},"
                         .format(predict[0:5], labels[0:5],
                                 correct / total + 0.))
             trainingTimes += 1
         if e > swa_start:
             swa_model.update_parameters(model)
             swa_scheduler.step()
         else:
             scheduler.step()
     swa_model.to("cpu")
     torch.optim.swa_utils.update_bn(fishDataLoader, swa_model)
     torch.save(swa_model.state_dict(), checkPoint)
 else:
     print("Test samples are {}".format(testSamples))
     print("No shuffle labels {}".format(intTestLabels))
     newTestLabel = []
     if if_random_labels:
         for i in range(len(uniqueLabels)):
             newTestLabel += [i for _ in range(batch_size)]
         dis = len(intTestLabels) - len(newTestLabel)
         for i in range(dis):
             newTestLabel.append(np.random.randint(0, len(uniqueLabels)))
Ejemplo n.º 22
0
def fine_tune(EPOCHES, BATCH_SIZE, train_image_paths, train_label_paths,
              val_image_paths, val_label_paths, channels, model_path,
              swa_model_path, addNDVI):

    train_loader = get_dataloader(train_image_paths,
                                  train_label_paths,
                                  "train",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8)
    valid_loader = get_dataloader(val_image_paths,
                                  val_label_paths,
                                  "val",
                                  addNDVI,
                                  BATCH_SIZE,
                                  shuffle=False,
                                  num_workers=8)

    # Define the model, optimizer, and loss function
    # model = smp.UnetPlusPlus(
    #         encoder_name="efficientnet-b7",
    #         encoder_weights="imagenet",
    #         in_channels=channels,
    #         classes=10,
    # )
    model = smp.UnetPlusPlus(
        encoder_name="resnet101",
        encoder_weights="imagenet",
        in_channels=channels,
        classes=10,
    )
    model.to(DEVICE)
    model.load_state_dict(torch.load(model_path))
    # Use the SGD optimizer
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=3e-4,
                                weight_decay=1e-3,
                                momentum=0.9)

    # Stochastic Weight Averaging (SWA) for better generalization
    swa_model = AveragedModel(model).to(DEVICE)
    # SWA learning-rate schedule
    swa_scheduler = SWALR(optimizer, swa_lr=1e-5)

    # LovaszLoss directly optimizes the mIoU loss via the convex Lovasz extension of submodular losses
    loss_fn = LovaszLoss(mode='multiclass').to(DEVICE)

    header = r'Epoch/EpochNum | TrainLoss | ValidmIoU | Time(m)'
    raw_line = r'{:5d}/{:8d} | {:9.3f} | {:9.3f} | {:9.2f}'
    print(header)

    #    # Instantiate a GradScaler before training starts; only needed when using autocast
    #    scaler = GradScaler()

    # Track the best validation mIoU so far, to decide whether to save the current model
    best_miou = 0
    train_loss_epochs, val_mIoU_epochs, lr_epochs = [], [], []
    # Start training
    for epoch in range(1, EPOCHES + 1):
        # print("Start training the {}st epoch...".format(epoch))
        # Store the loss of every training batch
        losses = []
        start_time = time.time()
        model.train()
        model.to(DEVICE)
        for batch_index, (image, target) in enumerate(train_loader):
            image, target = image.to(DEVICE), target.to(DEVICE)
            # Manually zero the gradients before backpropagation
            optimizer.zero_grad()
            #            # Use autocast (fp16) to speed up training; run the forward pass (model + loss) under autocast
            #            with autocast(): #need pytorch>1.6
            # Forward pass to get the model output
            output = model(image)
            # Compute the loss for this batch
            loss = loss_fn(output, target)
            #                scaler.scale(loss).backward()
            #                scaler.step(optimizer)
            #                scaler.update()
            # Backpropagate to compute the gradients
            loss.backward()
            # Update the weights
            optimizer.step()
            losses.append(loss.item())
        # Fine-tuning averages from the first epoch: update the SWA weights and step the SWA LR
        swa_model.update_parameters(model)
        swa_scheduler.step()
        # Compute IoU on the validation set
        val_iou = cal_val_iou(model, valid_loader)
        # Print the per-class IoU on the validation set
        # print('\t'.join(np.stack(val_iou).mean(0).round(3).astype(str)))
        # Record this epoch's train_loss, val_mIoU, and lr
        train_loss_epochs.append(np.array(losses).mean())
        val_mIoU_epochs.append(np.mean(val_iou))
        lr_epochs.append(optimizer.param_groups[0]['lr'])
        # Print progress
        print(raw_line.format(epoch, EPOCHES,
                              np.array(losses).mean(), np.mean(val_iou),
                              (time.time() - start_time) / 60),
              end="")
        if best_miou < np.stack(val_iou).mean(0).mean():
            best_miou = np.stack(val_iou).mean(0).mean()
            torch.save(model.state_dict(), model_path[:-4] + "_finetune.pth")
            print("  valid mIoU improved; the model was saved.")
        else:
            print("")
    # Finally, update the BatchNorm statistics of the SWA model
    torch.optim.swa_utils.update_bn(train_loader, swa_model, device=DEVICE)
    # Compute the SWA model's IoU on the validation set
    val_iou = cal_val_iou(swa_model, valid_loader)
    print("swa_model mIoU is {}".format(np.mean(val_iou)))
    torch.save(swa_model.state_dict(), swa_model_path)
    return train_loss_epochs, val_mIoU_epochs, lr_epochs
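
# --- Usage sketch (assumption, not part of the original example) ---
# fine_tune is typically called from a driver script; every path and literal below
# is hypothetical.
# train_loss, val_miou, lrs = fine_tune(
#     EPOCHES=30, BATCH_SIZE=8,
#     train_image_paths=train_image_paths, train_label_paths=train_label_paths,
#     val_image_paths=val_image_paths, val_label_paths=val_label_paths,
#     channels=4, model_path="unetpp_resnet101.pth",
#     swa_model_path="unetpp_resnet101_swa.pth", addNDVI=True)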