Example #1
                'temperature':          temp
            }

            save_model(f'./vae.pt')
            wandb.save('./vae.pt')

            # temperature anneal

            temp = max(temp * math.exp(-ANNEAL_RATE * global_step), TEMP_MIN)
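            # (temp is annealed multiplicatively toward the TEMP_MIN floor; the
            # factor exp(-ANNEAL_RATE * global_step) shrinks as training progresses)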

            # lr decay

            sched.step()

        if i % 10 == 0:
            lr = sched.get_last_lr()[0]
            print(epoch, i, f'lr - {lr:6f} loss - {loss.item()}')

            logs = {
                **logs,
                'epoch': epoch,
                'iter': i,
                'loss': loss.item(),
                'lr': lr
            }

        wandb.log(logs)
        global_step += 1

    # save trained model to wandb as an artifact every epoch's end
Example #2
class PolicyGradient():
    """
    A policy gradient agent implementation
    """
    def __init__(self,
                 state_dim,
                 act_dim,
                 pol_hid_lyrs=[16],
                 val_hid_lyrs=[16],
                 pol_lr=0.001,
                 val_lr=0.005,
                 lr_decay=1,
                 min_pol_lr=0,
                 min_val_lr=0,
                 pol_act=nn.Tanh,
                 val_act=nn.Tanh,
                 discount=1,
                 batch_size=5000,
                 is_discrete=False,
                 baseline=True,
                 seed=None,
                 use_gpu=True):
        # Seeding is fully reproducible only when running on the CPU
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
        self.is_discrete = is_discrete
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.pol_lr = pol_lr
        self.lr_decay = lr_decay
        self.min_pol_lr = min_pol_lr
        self.discount = discount
        self.batch_size = batch_size
        self.rewards_processed = 0
        self.baseline = baseline
        self.num_trajectories = 0
        self.reqd_trajectories = 10
        self.num_epoch = 0
        # set device for computation
        if torch.cuda.is_available() and use_gpu:
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        # Instantiate policy neural network
        pol_layers = [self.state_dim] + pol_hid_lyrs + [self.act_dim]
        if self.is_discrete:
            self.policy = CategoricalPolicy(pol_layers, pol_act, nn.Identity)
        else:
            self.policy = GaussianPolicy(pol_layers, pol_act, nn.Identity)
        self.policy.to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), self.pol_lr)
        self.policy_lr_sch = ExponentialLR(self.policy_optim, self.lr_decay)
        self.policy_loss = torch.zeros(1,
                                       device=self.device,
                                       dtype=torch.float32)
        # Instantiate neural network for value function
        if self.baseline:
            val_layers = [self.state_dim] + val_hid_lyrs + [1]
            self.val_lr = val_lr
            self.min_val_lr = min_val_lr
            self.value_fn = ValueFunction(val_layers, val_act)
            self.value_fn.to(self.device)
            self.value_optim = Adam(self.value_fn.parameters(), self.val_lr)
            self.value_lr_sch = ExponentialLR(self.value_optim, self.lr_decay)
            self.value_fn_loss = torch.zeros(1,
                                             device=self.device,
                                             dtype=torch.float32)
        # experience buffers
        self.state_buffer = []
        self.action_buffer = []
        self.reward_buffer = []
        self.discount_buffer = []

    def compute_grad(self, terminal=True):
        """ Calculate the gradient of the policy and value fn """
        raise NotImplementedError

    def start(self, state):
        """
        Start the agent for the episode
        Receive the agent state and return an action
        """
        self.state_buffer.append(state)
        state_tensor = torch.as_tensor(state,
                                       device=self.device,
                                       dtype=torch.float32)
        action = self.policy.get_action(state_tensor)
        self.action_buffer.append(action)
        self.episode_reward = 0
        self.discount_buffer.append(1)
        return action

    def take_step(self, reward, state):
        """
        Agent stores experience and selects the next action
        Performs policy update if experience buffer is full
        """
        self.reward_buffer.append(reward)
        self.state_buffer.append(state)
        state_tensor = torch.as_tensor(state,
                                       device=self.device,
                                       dtype=torch.float32)
        action = self.policy.get_action(state_tensor)
        self.action_buffer.append(action)
        self.discount_buffer.append(self.discount_buffer[-1] * self.discount)
        self.episode_reward += reward
        return action

    def end(self, reward):
        """
        Agent performs policy update with available experience
        and resets variables for next episode
        """
        self.reward_buffer.append(reward)
        self.episode_reward += reward
        self.num_trajectories += 1
        # Perform policy update
        policy_loss, value_fn_loss = self.compute_grad(True)
        # Empty buffers
        self.state_buffer.clear()
        self.action_buffer.clear()
        self.reward_buffer.clear()
        self.discount_buffer.clear()
        return self.episode_reward, policy_loss, value_fn_loss

    def update_network(self):
        """
        Performs an update of the neural network parameters
        """
        # Update parameters
        if ((self.rewards_processed + len(self.reward_buffer) >=
             self.batch_size)
                or (self.num_trajectories >= self.reqd_trajectories)):
            # Update policy
            # Calculate mean loss
            self.policy_loss = self.policy_loss / (self.rewards_processed +
                                                   len(self.reward_buffer))
            # Backpropagate
            self.policy_loss.backward()
            # Update parameters
            self.policy_optim.step()
            # Schedule learning rate
            if self.policy_lr_sch.get_last_lr()[0] > self.min_pol_lr:
                self.policy_lr_sch.step()
            # Clear gradients
            self.policy_optim.zero_grad()
            # Reset variables
            policy_loss = self.policy_loss.item()
            # Update value function
            if self.baseline:
                # Calculate mean loss
                self.value_fn_loss = \
                    self.value_fn_loss/(self.rewards_processed +
                                        len(self.reward_buffer))
                # Backpropagate
                self.value_fn_loss.backward()
                # Update parameters
                self.value_optim.step()
                # Schedule learning rate
                if self.value_lr_sch.get_last_lr()[0] > self.min_val_lr:
                    self.value_lr_sch.step()
                # Clear gradients
                self.value_optim.zero_grad()
                # Reset variables
                value_fn_loss = self.value_fn_loss.item()

            self.rewards_processed = 0
            self.num_trajectories = 0
            self.num_epoch += 1

            self.policy_loss = torch.zeros(1,
                                           device=self.device,
                                           dtype=torch.float32)
            if self.baseline:
                self.value_fn_loss = torch.zeros(1,
                                                 device=self.device,
                                                 dtype=torch.float32)
                return policy_loss, value_fn_loss
            else:
                return policy_loss, None
        else:
            self.rewards_processed += len(self.reward_buffer)
            return None, None
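
# A minimal usage sketch (not part of the original source): PolicyGradient is
# abstract, since compute_grad() must be supplied by a subclass (for instance a
# REINFORCE-style agent that accumulates self.policy_loss / self.value_fn_loss
# and calls update_network()). `ReinforceAgent` and `env` below are hypothetical
# stand-ins for such a subclass and a Gym-style environment.
def run_episodes(env, num_episodes=100):
    agent = ReinforceAgent(state_dim=env.observation_space.shape[0],
                           act_dim=env.action_space.n,
                           is_discrete=True,
                           seed=0)
    returns = []
    for _ in range(num_episodes):
        state = env.reset()
        action = agent.start(state)
        done = False
        while not done:
            state, reward, done, _ = env.step(action)
            if done:
                # end() performs the policy update and clears the experience buffers
                episode_return, pol_loss, val_loss = agent.end(reward)
                returns.append(episode_return)
            else:
                action = agent.take_step(reward, state)
    return returns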
Example #3
class ThreeStageNetwork():
    def __init__(self,
                 num_classes=101,
                 embedding_size=512,
                 trunk_architecture="efficientnet-b0",
                 trunk_optim="RMSprop",
                 embedder_optim="RMSprop",
                 classifier_optim="RMSprop",
                 trunk_lr=1e-4,
                 embedder_lr=1e-3,
                 classifier_lr=1e-3,
                 weight_decay=1.5e-6,
                 trunk_decay=0.98,
                 embedder_decay=0.93,
                 classifier_decay=0.93,
                 log_train=True,
                 gpu_id=0):
        """
        Inputs:
            num_classes int: Number of Classes (for Classifier purely)
            embedding_size int: The size of embedding space output from Embedder
            trunk_architecture str: To pass to self.get_trunk() either efficientnet-b{i} or resnet-18/50 or mobilenet
            trunk_optim optim: Which optimizer to use, such as adamW
            embedder_optim optim: Which optimizer to use, such as adamW
            classifier_optim optim: Which optimizer to use, such as adamW
            trunk_lr float: The learning rate for the Trunk Optimizer
            embedder_lr float: The learning rate for the Embedder Optimizer
            classifier_lr float: The learning rate for the Classifier Optimizer
            weight_decay float: The weight decay for all 3 optimizers
            trunk_decay float: The multiplier for the Scheduler y_{t+1} <- trunk_decay * y_{t}
            embedder_decay float: The multiplier for the Scheduler y_{t+1} <- embedder_decay * y_{t}
            classifier_decay float: The multiplier for the Scheduler y_{t+1} <- classifier_decay * y_{t}
            log_train Bool: whether or not to save training logs
            gpu_id Int: Only currently used to track the GPU usage
        """

        self.gpu_id = gpu_id
        #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device("cuda")
        self.pretrained = False  # this is used to load the indices for train/val data for now
        self.log_train = log_train

        # build three stage network
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.MLP_neurons = 2048  # output size of neural network + size used inside embedder/classifier MLP

        self.get_trunk(trunk_architecture)
        self.trunk = nn.DataParallel(self.trunk.to(self.device))
        self.embedder = nn.DataParallel(
            Network(layer_sizes=[self.MLP_neurons, self.embedding_size],
                    neuron_fc=self.MLP_neurons).to(self.device))
        self.classifier = nn.DataParallel(
            Network(layer_sizes=[self.embedding_size, self.num_classes],
                    neuron_fc=self.MLP_neurons).to(self.device))

        # build optimizers
        self.trunk_optimizer = self.get_optimizer(trunk_optim,
                                                  self.trunk.parameters(),
                                                  lr=trunk_lr,
                                                  weight_decay=weight_decay)
        self.embedder_optimizer = self.get_optimizer(
            embedder_optim,
            self.embedder.parameters(),
            lr=embedder_lr,
            weight_decay=weight_decay)
        self.classifier_optimizer = self.get_optimizer(
            classifier_optim,
            self.classifier.parameters(),
            lr=classifier_lr,
            weight_decay=weight_decay)

        # build schedulers
        self.trunk_scheduler = ExponentialLR(self.trunk_optimizer,
                                             gamma=trunk_decay)
        self.embedder_scheduler = ExponentialLR(self.embedder_optimizer,
                                                gamma=embedder_decay)
        self.classifier_scheduler = ExponentialLR(self.classifier_optimizer,
                                                  gamma=classifier_decay)
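        # each scheduler.step() multiplies its optimizer's learning rate by the
        # corresponding decay factor (ExponentialLR): lr_{t+1} = decay * lr_{t}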

        # build pair based losses and the miner
        self.triplet = losses.TripletMarginLoss(margin=0.2).to(self.device)
        self.multisimilarity = losses.MultiSimilarityLoss(alpha=2,
                                                          beta=50,
                                                          base=1).to(
                                                              self.device)
        self.miner = miners.MultiSimilarityMiner(epsilon=0.1)
        # build proxy anchor loss
        self.proxy_anchor = Proxy_Anchor(nb_classes=num_classes,
                                         sz_embed=embedding_size,
                                         mrg=0.2,
                                         alpha=32).to(self.device)
        self.proxy_optimizer = AdamW(self.proxy_anchor.parameters(),
                                     lr=trunk_lr * 10,
                                     weight_decay=1.5E-6)
        self.proxy_scheduler = ExponentialLR(self.proxy_optimizer, gamma=0.8)
        # finally crossentropy loss
        self.crossentropy = torch.nn.CrossEntropyLoss().to(self.device)

        # log some of this information
        self.model_params = {
            "Trunk_Model":
            trunk_architecture,
            "Optimizers": [
                str(self.trunk_optimizer),
                str(self.embedder_optimizer),
                str(self.classifier_optimizer)
            ],
            "Embedder":
            str(self.embedder),
            "Embedding_Dimension":
            str(embedding_size),
            "Weight_Decay":
            weight_decay,
            "Scheduler_Decays":
            [trunk_decay, embedder_decay, classifier_decay],
            "Embedding_Size":
            embedding_size,
            "Learning_Rates": [trunk_lr, embedder_lr, classifier_lr],
            "Miner":
            str(self.miner)
        }

    def get_optimizer(self, optim, params, lr, weight_decay):

        if optim == "adamW":
            return torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)
        elif optim == "SGD":
            return torch.optim.SGD(params,
                                   lr=lr,
                                   momentum=0.9,
                                   weight_decay=weight_decay,
                                   nesterov=True)
        elif optim == "RMSprop":
            return torch.optim.RMSprop(params,
                                       lr=lr,
                                       weight_decay=weight_decay)
        else:
            return None

    def get_trunk(self, architecture):

        if "efficientnet" in architecture.lower():
            self.trunk = EfficientNet.from_pretrained(
                architecture, num_classes=self.MLP_neurons)

        elif "resnet" in architecture.lower():
            if "18" in architecture.lower():
                self.trunk = models.resnet18(pretrained=True)
                self.trunk.fc = nn.Linear(512, self.MLP_neurons)

            elif "50" in architecture.lower():
                self.trunk = models.resnext50_32x4d(pretrained=True)
                self.trunk.fc = nn.Linear(2048, self.MLP_neurons)

        elif "mobilenet" in architecture.lower():
            self.trunk = models.mobilenet_v2(pretrained=True)
            self.trunk.classifier[1] = torch.nn.Linear(1280, self.MLP_neurons)

    def get_embeddings_logits(self,
                              dataset,
                              indices,
                              batch_size=128,
                              num_workers=16,
                              return_collisions=False):
        """
        This can be used for inference but is not super appropriate since
        it requires the dataset/indices
        """

        # build a temporary dataloader
        temp_sampler = SubsetRandomSampler(indices)
        temp_loader = DataLoader(dataset=dataset,
                                 sampler=temp_sampler,
                                 batch_size=batch_size,
                                 num_workers=num_workers)
        tot_embeds = []
        tot_logits = []
        tot_labels = []
        accuracies = []

        if return_collisions:
            # initialize weights to count # of collisions
            class_weights = np.zeros(self.num_classes) + 0.2
            label_count = torch.ones(self.num_classes).cuda()

        # turn all models into eval mode
        self.trunk.eval()
        self.embedder.eval()
        self.classifier.eval()

        n_iter = len(temp_loader.sampler) // batch_size
        # turn grad off for evaluation
        with torch.no_grad():
            print("Getting Embeddings")
            with tqdm(total=n_iter) as t:
                for i, data in enumerate(temp_loader):
                    im, labels = data

                    # forward pass for each model
                    fc_out = self.trunk(im)
                    embeds = self.embedder(fc_out)
                    logits = self.classifier(embeds)

                    if return_collisions:
                        preds = knn_sim(embeds,
                                        labels,
                                        k=self.M,
                                        distance_weighted=False,
                                        local_normalization=False,
                                        num_classes=self.num_classes)
                        weights = impostor_weights(
                            preds,
                            labels,
                            k=self.M,
                            num_classes=self.num_classes)
                        class_weights += weights.cpu().numpy()
                        label_count = label_count.scatter_add(
                            0, labels,
                            torch.ones(len(labels)).cuda())

                    # embeds -> to cpu and then to array
                    accuracies.append(
                        calc_accuracy(logits, labels.to(self.device)))
                    tot_embeds.append(embeds.cpu().numpy())
                    tot_logits.append(logits.cpu().numpy())
                    tot_labels.append(labels.cpu().numpy())

                    t.update()

        # return the np arrays
        tot_embeds = np.concatenate(tot_embeds, axis=0)
        tot_logits = np.concatenate(tot_logits, axis=0)
        tot_labels = np.concatenate(tot_labels, axis=0)

        print("logits shape", tot_logits.shape)
        print("Accuracy is", np.mean(accuracies))

        del temp_loader, temp_sampler

        if return_collisions:
            return tot_embeds, tot_logits, tot_labels, np.mean(
                accuracies), class_weights / label_count.cpu().numpy()
        else:
            return tot_embeds, tot_logits, tot_labels, np.mean(accuracies)

    def save_all_logits_embeds(self, path):
        """
        This is usually run at the end; it saves all logits/embeddings of the train
        and validation datasets to disk. Warning: the holdout set is not currently included.
        """

        tembeds, tlogits, tlabels, _ = self.get_embeddings_logits(
            self.val_dataset,
            self.train_indices,
            batch_size=self.batch_size * 4,
            num_workers=self.num_workers)
        vembeds, vlogits, vlabels, _ = self.get_embeddings_logits(
            self.val_dataset,
            self.val_indices,
            batch_size=self.batch_size * 4,
            num_workers=self.num_workers)

        np.savez(path,
                 tembeds=tembeds,
                 tlogits=tlogits,
                 tlabels=tlabels,
                 vembeds=vembeds,
                 vlogits=vlogits,
                 vlabels=vlabels)

    def image_inference(self, image):
        # image should be a preprocessed torch tensor (batched, e.g. shape (N, 3, H, W))

        self.trunk.eval()
        self.embedder.eval()
        self.classifier.eval()

        with torch.no_grad():
            fc_out = self.trunk(image.to(self.device))
            embeds = self.embedder(fc_out)
            logits = self.classifier(embeds)

        return embeds, logits

    def save_model(self, path):
        """
        This function is used to save the state dictionaries of
        all three models and their corresponding optimizers to
        the input path provided, under the name "models.h5"
        """

        print("Saving model to", path)
        torch.save(
            {
                "trunk_state_dict":
                self.trunk.state_dict(),
                "embedder_state_dict":
                self.embedder.state_dict(),
                "classifier_state_dict":
                self.classifier.state_dict(),
                "trunk_optimizer_state_dict":
                self.trunk_optimizer.state_dict(),
                "embedder_optimizer_state_dict":
                self.embedder_optimizer.state_dict(),
                "classifier_optimizer_state_dict":
                self.classifier_optimizer.state_dict(),
            }, path + "/models.h5")

    def load_weights(self,
                     path,
                     load_trunk=True,
                     load_embedder=True,
                     load_classifier=True,
                     partial_classifier=False,
                     load_optimizers=True):
        """
        This function is used to continue training or to use a pretrained model;
        it loads a file saved by the save_model() method above from the
        given path. It will also set the pretrained flag to True.
        """

        weights = torch.load(path)

        self.pretrained = True
        loaded = []

        if load_trunk:
            self.trunk.load_state_dict(weights["trunk_state_dict"])
            if load_optimizers:
                self.trunk_optimizer.load_state_dict(
                    weights["trunk_optimizer_state_dict"])
            loaded.append("Trunk")

        if load_embedder:
            self.embedder.load_state_dict(weights["embedder_state_dict"])
            if load_optimizers:
                self.embedder_optimizer.load_state_dict(
                    weights["embedder_optimizer_state_dict"])
            loaded.append("Embedder")

        if load_classifier:
            self.classifier.load_state_dict(weights["classifier_state_dict"])
            if load_optimizers:
                self.classifier_optimizer.load_state_dict(
                    weights["classifier_optimizer_state_dict"])
            loaded.append("Classifier")

        if partial_classifier:
            print("Partial Loading of classifier")
            # load overall network
            self.classifier = nn.DataParallel(
                Network(layer_sizes=[self.embedding_size, 693],
                        neuron_fc=self.MLP_neurons).to(self.device))
            self.classifier.load_state_dict(weights["classifier_state_dict"])
            # replace last layer with number of classes
            self.classifier.module.classifier[4] = nn.Linear(
                self.MLP_neurons, self.num_classes).to(self.device)

        print("Loaded pretrained weights for", path, "for", loaded)

    def augmented_embeds(self, image, N):
        # TEMPORARY METHOD WILL BE REMOVED AS IT ISN'T USEFUL RIGHT NOW
        """
        Pass PIL Image and get back N augmented versions of the image,
        as well as the original
        """

        # setup the augmentation, this code should change, pretty ugly!
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=(0.8, 1.3),
                                   contrast=(0.8, 1.2),
                                   saturation=(0.9, 1.2),
                                   hue=(-0.05, 0.05)),
            transforms.RandomRotation(degrees=(-5, 5), expand=True),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        val_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        # augment the image N times and return
        self.trunk.eval()
        self.embedder.eval()

        tot_embeds = []
        with torch.no_grad():

            # get original image with val transform
            fc_out = self.trunk(
                val_transform(image).to(self.device).reshape(1, 3, 224, 224))
            embeds = self.embedder(fc_out)
            tot_embeds.append(embeds.squeeze())

            # now get N augmented images
            for _ in range(N):
                fc_out = self.trunk(
                    transform(image).to(self.device).reshape(1, 3, 224, 224))
                embeds = self.embedder(fc_out)
                tot_embeds.append(embeds.squeeze())

        return torch.stack(tot_embeds)

    def setup_data(self,
                   dataset,
                   h5_path=None,
                   batch_size=128,
                   num_workers=16,
                   M=3,
                   train_split=0.8,
                   labels=None,
                   repeat_indices=1,
                   train_labels=None,
                   load_indices=False,
                   indices_path=None,
                   max_batches=None,
                   log_save_path="logs"):
        """
        This method is meant to be used prior to training in order
        to get the appropriate dataloaders, datasets, indices etc...
        I'm not sure if the need for this method suggests bad design?

        Inputs:
            h5_path str: the hdf5 dataset path used for training
            batch_size int: the batch size used for training later
            num_workers int: the number of workers to pass to training dataloader
            labels None or array: the array of all labels
            train_labels None or array: the labels to train on, if none will train on all
            load_indices Bool: whether to load train/val/holdout indices, usually used if
                               you are continuing to train a pretrained model. Careful as
                               incorrect loading of indices can result in test set leakage.
            indices_path str: Where the indices you wish to load are located
            log_save_path str: which directory to save training logs to, such as batch_history.csv
        """

        if labels is None and h5_path is not None:
            self.labels = h5py.File(h5_path, "r")["labels"][:]
        else:
            self.labels = labels

        if load_indices:
            arr = np.load(indices_path)
            train_indices, val_indices, holdout_indices = arr["train"], arr[
                "val"], arr["holdout"]
        else:
            if self.pretrained is True:
                warnings.warn(
                    "Picking random indices, dangerous if you are continuing training!"
                )
            train_indices, val_indices, holdout_indices = get_train_val_holdout_indices(
                labels=self.labels,
                train_labels=train_labels,
                train_split=train_split)
            if self.log_train:
                np.savez(log_save_path + "/data_indices.npz",
                         train=train_indices,
                         val=val_indices,
                         holdout=holdout_indices)

        self.augmentations = data_augmentation(hflip=True,
                                               crop=False,
                                               colorjitter=True,
                                               rotations=False,
                                               affine=False,
                                               imagenet=True)

        trainloader, val_dataset = get_dataloaders(
            dataset=dataset,
            h5_path=h5_path,
            batch_size=batch_size,
            num_workers=num_workers,
            augmentations=self.augmentations,
            M=M,
            labels=self.labels,
            train_indices=np.repeat(train_indices, repeat_indices),
            max_batches=max_batches)

        self.max_batches = max_batches
        self.dataset = dataset
        self.trainloader = trainloader
        self.val_dataset = val_dataset
        self.train_indices = train_indices
        self.val_indices = val_indices
        self.holdout_indices = holdout_indices
        self.batch_size = batch_size
        self.log_save_path = log_save_path
        self.num_workers = num_workers
        self.M = M
        self.repeat_indices = repeat_indices

    def train(self,
              n_epochs,
              loss_ratios=[1, 1, 1, 3],
              class_weighting=False,
              model_save_path="models",
              model_name="models.h5",
              epoch_train=False,
              epoch_val=True,
              epoch_save=False,
              save_trunk=True,
              save_embedder=True,
              save_classifier=True,
              train_trunk=True):
        """
        This method is used for actually training the model, it is meant
        to be called after the setup_data() method and won't function
        properly without it as it will not have access to some attributes.

        Inputs:
            n_epochs int: the amount of epochs to train for
            loss_ratios array: the ratios to pass to triplet, multisimilarity,
                               proxy anchor and crossentropy in that order
            model_save_path str: where to save models at checkpoints and at end
            log_train Bool: whether to log the train files
        """

        self.model_params["Loss_Ratios"] = loss_ratios

        # Set up the GPU handle to report usage/VRAM usage
        nvmlInit()
        handle = nvmlDeviceGetHandleByIndex(self.gpu_id)

        # Set up the logging
        self.batch_history = {
            "Iteration": [],
            "Loss": [],
            "Losses": [],
            "Accuracy": [],
            "GPU_useage": [],
            "GPU_mem": [],
            "Time": []
        }
        self.epoch_history = {
            "Epoch": [],
            "Train_Accuracy": [],
            "Val_Accuracy": [],
            "Learning_Rates": [],
            "Time": []
        }
        batch_log_path = self.log_save_path + "/batch_history.csv"
        epoch_log_path = self.log_save_path + "/epoch_history.csv"

        if self.log_train is True and os.path.exists(batch_log_path) is False:
            with open(batch_log_path, "a+") as f:
                writer = csv.writer(f)
                writer.writerow(list(self.batch_history.keys()))
            with open(epoch_log_path, "a+") as f:
                writer = csv.writer(f)
                writer.writerow(list(self.epoch_history.keys()))

        # This is purely used for model checkpoints to save the best epoch model
        best_val_accuracy = 0

        # setup AMP GradScaler
        scaler = torch.cuda.amp.GradScaler()

        print(f"Starting training with {n_epochs} Epochs.")
        n_iters = len(self.trainloader.sampler) // self.batch_size
        for epoch in range(n_epochs):
            # set our models to train mode
            if train_trunk:
                self.trunk.train()
            else:
                self.trunk.eval()
            self.embedder.train()
            self.classifier.train()

            # initialize our batch accuracy and loss parameters that are later used
            # to compute a rolling mean.
            batch_acc = 0
            batch_loss = 0
            batch_acc_queue = []
            batch_loss_queue = []

            performance_dict = {
                "Load_Data": 0,
                "Forward_Pass": 0,
                "Mining": 0,
                "Compute_Loss": 0,
                "Optim_Step": 0,
                "Logging": 0,
                "Total": 0
            }

            # initialize class weights to be uniform distribution if no class_weighting
            if class_weighting:
                # start in a smoothed fashion with 5 collisions each
                class_weights = np.zeros(self.num_classes) + 0.2
            else:
                class_weights = np.zeros(
                    self.num_classes) + 0.2  #/ self.num_classes
            label_count = torch.ones(self.num_classes).cuda()

            with tqdm(total=int(n_iters)) as t:
                start_t = time()
                for i, data in enumerate(self.trainloader):
                    inputs, labels = data
                    inputs, labels = inputs.to(self.device), labels.to(
                        self.device)

                    # zero the parameter gradients
                    if train_trunk:
                        self.trunk_optimizer.zero_grad()
                        self.trunk.zero_grad()

                    self.embedder_optimizer.zero_grad()
                    self.embedder.zero_grad()
                    self.classifier_optimizer.zero_grad()
                    self.classifier.zero_grad()
                    self.proxy_optimizer.zero_grad()

                    # forward pass
                    with autocast():
                        time_check = time()
                        fc_out = self.trunk(inputs)
                        embeddings = self.embedder(fc_out)
                        logits = self.classifier(embeddings)
                        performance_dict["Forward_Pass"] += time() - time_check

                        # mine interesting pairs
                        time_check = time()
                        if loss_ratios[0] + loss_ratios[1] != 0:
                            hard_pairs = self.miner(embeddings, labels)
                        performance_dict["Mining"] += time() - time_check

                        # compute loss, the conditionals are to speed up compute if a loss
                        # has been switched off.
                        time_check = time()
                        loss = 0
                        curr_losses = []
                        if loss_ratios[0] != 0:
                            triplet_loss_curr = self.triplet(
                                embeddings, labels, hard_pairs)
                            curr_losses.append(triplet_loss_curr.item() *
                                               loss_ratios[0])
                            loss += triplet_loss_curr * loss_ratios[0]

                        if loss_ratios[1] != 0:
                            ms_loss_curr = self.multisimilarity(
                                embeddings, labels, hard_pairs)
                            curr_losses.append(ms_loss_curr.item() *
                                               loss_ratios[1])
                            loss += ms_loss_curr * loss_ratios[1]

                        if loss_ratios[2] != 0:
                            proxy_loss_curr = self.proxy_anchor(
                                embeddings, labels)
                            curr_losses.append(proxy_loss_curr.item() *
                                               loss_ratios[2])
                            loss += proxy_loss_curr * loss_ratios[2]

                        if loss_ratios[3] != 0:
                            cse_loss_curr = self.crossentropy(
                                logits,
                                labels.to(self.device).long())
                            curr_losses.append(cse_loss_curr.item() *
                                               loss_ratios[3])
                            loss += cse_loss_curr * loss_ratios[3]

                    scaler.scale(loss).backward()
                    performance_dict["Compute_Loss"] += time() - time_check

                    # now take a step
                    time_check = time()
                    if train_trunk:
                        scaler.step(self.trunk_optimizer)
                    scaler.step(self.embedder_optimizer)
                    scaler.step(self.classifier_optimizer)
                    scaler.step(self.proxy_optimizer)

                    scaler.update()

                    #if class_weighting:
                    # compute batch label weightings
                    k = 1  #self.M
                    preds = knn_sim(embeddings,
                                    labels,
                                    k=k,
                                    distance_weighted=False,
                                    local_normalization=False,
                                    num_classes=self.num_classes)
                    weights = impostor_weights(preds,
                                               labels,
                                               k=k,
                                               num_classes=self.num_classes)
                    #weights, associated_labels = get_weights(preds, labels)
                    # moving average weight calculation (x[l] = x[l] + (new_data - x[l])/(i+1))
                    class_weights += weights.cpu().numpy()
                    label_count = label_count.scatter_add(
                        0, labels,
                        torch.ones(len(labels)).cuda())
                    #class_weights += (weights.cpu().numpy() - class_weights)/(i+1)
                    #class_weights[associated_labels] += (weights - class_weights[associated_labels])/(i+1)

                    performance_dict["Optim_Step"] += time() - time_check

                    time_check = time()
                    # compute a rolling mean over a queue of length 2048 // batch_size
                    batch_acc_queue.append(
                        calc_accuracy(logits, labels.to(self.device)))
                    batch_loss_queue.append(loss.item())
                    if len(batch_acc_queue) >= 2048 // self.batch_size:
                        batch_acc_queue.pop(0)
                        batch_loss_queue.pop(0)
                    batch_acc = np.mean(batch_acc_queue)
                    batch_loss = np.mean(batch_loss_queue)

                    res = nvmlDeviceGetUtilizationRates(handle)
                    # log the current batch information
                    if self.log_train:
                        self.batch_history["Iteration"].append(epoch *
                                                               n_iters + i)
                        self.batch_history["Loss"].append(batch_loss)
                        self.batch_history["Accuracy"].append(batch_acc)
                        self.batch_history["Time"].append(
                            datetime.now().strftime("%d/%m/%Y %H:%M:%S"))
                        self.batch_history["GPU_useage"].append(res.gpu)
                        self.batch_history["GPU_mem"].append(res.memory)

                        # write to CSV file, should change this to dict writer soon
                        with open(batch_log_path, "a") as f:
                            writer = csv.writer(f)
                            writer.writerow([
                                epoch * n_iters + i, batch_loss, batch_acc,
                                res.gpu, res.memory,
                                datetime.now().strftime("%d/%m/%Y %H:%M:%S")
                            ])

                    # now update our loading bar with new values of batch loss and accuracy
                    t.set_description('Epoch %i' % int(epoch))
                    t.set_postfix(loss=batch_loss,
                                  acc=batch_acc,
                                  gpu=res.gpu,
                                  gpuram=res.memory,
                                  losses=[np.round(i, 2) for i in curr_losses])
                    t.update()
                    performance_dict["Logging"] += time() - time_check

            # save class weights after each epoch
            #if class_weighting:
            # normalize by the number of times each label has occurred
            class_weights /= label_count.cpu().numpy()
            np.save(f"logs/class_weights_{epoch}.npy", class_weights)

            # build and save performance dictionary, keep in mind Load_Data is inaccurate due to prefetching
            performance_dict["Load_Data"] = sum(
                performance_dict[key] for key in performance_dict)
            performance_dict["Total"] = time() - start_t
            performance_dict["Load_Data"] = performance_dict[
                "Total"] - performance_dict["Load_Data"]
            performance_dict = {
                key: np.round(performance_dict[key], 2)
                for key in performance_dict
            }

            if self.log_train:
                with open(self.log_save_path + "/performance.json",
                          "w") as json_file:
                    json.dump(performance_dict, json_file)
            print(performance_dict)

            if class_weighting is True:
                # get the next dataloader based on class weights
                del self.trainloader
                self.trainloader, self.val_dataset = get_dataloaders(
                    dataset=self.dataset,
                    batch_size=self.batch_size,
                    num_workers=self.num_workers,
                    augmentations=self.augmentations,
                    M=self.M,
                    labels=self.labels,
                    train_indices=np.repeat(self.train_indices,
                                            self.repeat_indices),
                    class_weights=class_weights,
                    max_batches=self.max_batches)

            if epoch_train is True:
                # Train accuracy, embeddings and potential UMAP
                print("Training")
                em, lo, la, train_accuracy, collisions = self.get_embeddings_logits(
                    self.val_dataset, self.train_indices, self.batch_size * 2,
                    self.num_workers, True)
                np.save(f"logs/class_weights_{epoch}_val.npy", collision)

            if epoch_val is True:
                # Validation accuracy, loss, embeddings and potential UMAP
                print("Validation")
                em, lo, la, val_accuracy = self.get_embeddings_logits(
                    self.val_dataset, self.val_indices, self.batch_size * 2,
                    self.num_workers)
            else:
                # sadly, the way it's written we have to set this or the model won't
                # be saved, since val_accuracy wouldn't be greater than the last epoch's.
                val_accuracy = 0.001

            # finally we log the batch metrics
            if self.log_train:
                self.epoch_history["Learning_Rates"].append([
                    self.trunk_scheduler.get_last_lr(),
                    self.embedder_scheduler.get_last_lr(),
                    self.classifier_scheduler.get_last_lr()
                ])
                self.epoch_history["Epoch"].append(epoch)
                if epoch_train:
                    self.epoch_history["Train_Accuracy"].append(train_accuracy)
                else:
                    train_accuracy = 0
                if epoch_val:
                    self.epoch_history["Val_Accuracy"].append(val_accuracy)

                self.epoch_history["Time"].append(
                    datetime.now().strftime("%d/%m/%Y %H:%M:%S"))

                # write CSV file
                with open(epoch_log_path, "a") as f:
                    writer = csv.writer(f)
                    writer.writerow([
                        epoch, train_accuracy, val_accuracy,
                        [
                            self.trunk_scheduler.get_last_lr(),
                            self.embedder_scheduler.get_last_lr(),
                            self.classifier_scheduler.get_last_lr()
                        ],
                        datetime.now().strftime("%d/%m/%Y %H:%M:%S")
                    ])

            # check the learning rate schedulers
            if train_trunk:
                self.trunk_scheduler.step()
            self.embedder_scheduler.step()
            self.classifier_scheduler.step()
            self.proxy_scheduler.step()

            # save best model (based on validation accuracy)
            if val_accuracy >= best_val_accuracy or epoch_save is True:
                best_val_accuracy = val_accuracy
                # WARNING!!!! This MIGHT not work if parallel GPUs are used, then would
                # need to use model.module.state_dict() I believe? Not sure!
                save_dict = {}
                if save_trunk:
                    save_dict["trunk_state_dict"] = self.trunk.state_dict()
                    save_dict[
                        "trunk_optimizer_state_dict"] = self.trunk_optimizer.state_dict(
                        )
                if save_embedder:
                    save_dict[
                        "embedder_state_dict"] = self.embedder.state_dict()
                    save_dict[
                        "embedder_optimizer_state_dict"] = self.embedder_optimizer.state_dict(
                        )
                if save_classifier:
                    save_dict[
                        "classifier_state_dict"] = self.classifier.state_dict(
                        )
                    save_dict[
                        "classifier_optimizer_state_dict"] = self.classifier_optimizer.state_dict(
                        )

                torch.save(save_dict, model_save_path + "/" + model_name)

                # save the JSON including model details, this can be improved to take the mean
                self.model_params["final_val_accuracy"] = best_val_accuracy
                self.model_params["Time"] = datetime.now().strftime(
                    "%d/%m/%Y %H:%M:%S")
                self.model_params = {
                    k: str(self.model_params[k])
                    for k in self.model_params
                }
                if self.log_train:
                    with open(self.log_save_path + "/model_dict.json",
                              "w") as json_file:
                        json.dump(self.model_params, json_file)
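
# A minimal end-to-end sketch (not part of the original source) of how this class
# appears to be driven: construct the network, wire up the data with setup_data(),
# then call train(). `my_dataset` is a hypothetical torch Dataset compatible with
# the project's get_dataloaders(); the paths are placeholders.
def run_training(my_dataset, h5_path="data/train.h5"):
    net = ThreeStageNetwork(num_classes=101,
                            embedding_size=512,
                            trunk_architecture="efficientnet-b0")
    net.setup_data(dataset=my_dataset,
                   h5_path=h5_path,
                   batch_size=128,
                   num_workers=16,
                   train_split=0.8,
                   log_save_path="logs")
    net.train(n_epochs=10,
              loss_ratios=[1, 1, 1, 3],
              model_save_path="models")
    net.save_all_logits_embeds("logs/embeddings.npz")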
Example #4
def train(model: nn.Module, train_dataloader: torch.utils.data.DataLoader,
          val_dataloader: torch.utils.data.DataLoader):
    if ModelConfig.N_TO_N:
        loss_fn = CE_Loss()
    else:
        loss_fn = nn.CrossEntropyLoss()

    on_epoch_begin = model.reset_lstm_state if model.__class__.__name__ == "LRCN" else None

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=ModelConfig.LR,
                                 weight_decay=ModelConfig.REG_FACTOR)
    trainer = Trainer(model,
                      loss_fn,
                      train_dataloader,
                      val_dataloader,
                      ModelConfig.BATCH_SIZE,
                      optimizer=optimizer,
                      on_epoch_begin=on_epoch_begin)
    scheduler = ExponentialLR(optimizer, gamma=ModelConfig.LR_DECAY)
    if DataConfig.USE_TB:
        metrics = Metrics(model,
                          loss_fn,
                          train_dataloader,
                          val_dataloader,
                          DataConfig.LABEL_MAP,
                          n_to_n=ModelConfig.N_TO_N,
                          max_batches=None)
        tensorboard = TensorBoard(
            model,
            metrics,
            DataConfig.LABEL_MAP,
            DataConfig.TB_DIR,
            ModelConfig.GRAYSCALE,
            ModelConfig.IMAGE_SIZES,
            ModelConfig.BATCH_SIZE,
            ModelConfig.N_TO_N,
            write_graph=model.__class__.__name__ != "LRCN",
            sequence_length=ModelConfig.SEQUENCE_LENGTH)

    best_loss = float("inf")
    last_checkpoint_epoch = 0
    train_start_time = time.time()

    preprocess_fn = getattr(model, "preprocess", None)
    if not callable(preprocess_fn):
        preprocess_fn = None
    postprocess_fn = getattr(model, "postprocess", None)
    if not callable(postprocess_fn):
        postprocess_fn = None

    try:
        for epoch in range(ModelConfig.MAX_EPOCHS):
            epoch_start_time = time.perf_counter()
            print(f"\nEpoch {epoch}/{ModelConfig.MAX_EPOCHS}")

            epoch_loss = trainer.train_epoch()
            if DataConfig.USE_TB:
                tensorboard.write_loss(epoch, epoch_loss)
                tensorboard.write_lr(epoch, scheduler.get_last_lr()[0])

            if (epoch_loss < best_loss and DataConfig.USE_CHECKPOINT
                    and epoch >= DataConfig.RECORD_START
                    and (epoch - last_checkpoint_epoch) >=
                    DataConfig.CHECKPT_SAVE_FREQ):
                save_path = os.path.join(DataConfig.CHECKPOINT_DIR,
                                         f"train_{epoch}.pt")
                print(
                    f"\nLoss improved from {best_loss:.5e} to {epoch_loss:.5e},"
                    f"saving model to {save_path}",
                    end='\r')
                best_loss, last_checkpoint_epoch = epoch_loss, epoch
                torch.save(model.state_dict(), save_path)

            print(
                f"\nEpoch loss: {epoch_loss:.5e}  -  Took {time.perf_counter() - epoch_start_time:.5f}s"
            )

            # Validation and other metrics
            if epoch % DataConfig.VAL_FREQ == 0 and epoch >= DataConfig.RECORD_START:
                with torch.no_grad():
                    validation_start_time = time.perf_counter()
                    epoch_loss = trainer.val_epoch()

                    if DataConfig.USE_TB:
                        print("\nStarting to compute TensorBoard metrics",
                              end="\r",
                              flush=True)
                        # TODO: Uncomment line below and see if it works properly
                        # tensorboard.write_weights_grad(epoch)
                        tensorboard.write_loss(epoch,
                                               epoch_loss,
                                               mode="Validation")

                        # Metrics for the Train dataset
                        tensorboard.write_images(epoch,
                                                 train_dataloader,
                                                 input_is_video=True,
                                                 preprocess_fn=preprocess_fn,
                                                 postprocess_fn=postprocess_fn)
                        if epoch % (3 * DataConfig.VAL_FREQ) == 0:
                            tensorboard.write_videos(
                                epoch,
                                train_dataloader,
                                preprocess_fn=preprocess_fn,
                                postprocess_fn=postprocess_fn)
                        train_acc = tensorboard.write_metrics(
                            epoch, write_defect_acc=True)

                        # Metrics for the Validation dataset
                        tensorboard.write_images(epoch,
                                                 val_dataloader,
                                                 mode="Validation",
                                                 input_is_video=True,
                                                 preprocess_fn=preprocess_fn,
                                                 postprocess_fn=postprocess_fn)
                        if epoch % (3 * DataConfig.VAL_FREQ) == 0:
                            tensorboard.write_videos(
                                epoch,
                                val_dataloader,
                                mode="Validation",
                                preprocess_fn=preprocess_fn,
                                postprocess_fn=postprocess_fn)
                        val_acc = tensorboard.write_metrics(
                            epoch, mode="Validation", write_defect_acc=True)

                        print(
                            f"Train accuracy: {train_acc:.3f}  -  Validation accuracy: {val_acc:.3f}",
                            end='\r',
                            flush=True)

                    print(
                        f"\nValidation loss: {epoch_loss:.5e}  -"
                        f"  Took {time.perf_counter() - validation_start_time:.5f}s",
                        flush=True)
            scheduler.step()
    except KeyboardInterrupt:
        print("\n")

    train_stop_time = time.time()
    if DataConfig.USE_TB:
        tensorboard.close_writers()
    memory_peak, gpu_memory = resource_usage()
    print(
        "Finished Training"
        f"\n\tTraining time : {train_stop_time - train_start_time:.03f}s"
        f"\n\tRAM peak : {memory_peak // 1024} MB\n\tVRAM usage : {gpu_memory}"
    )
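
# A minimal invocation sketch (not part of the original source). `build_model` and
# `get_dataloaders` are hypothetical helpers standing in for however this project
# actually constructs the model and dataloaders from ModelConfig/DataConfig.
if __name__ == "__main__":
    net = build_model()
    train_dl, val_dl = get_dataloaders(ModelConfig.BATCH_SIZE)
    train(net, train_dl, val_dl)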
Example #5
def main():
    parser = get_parser()
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise ValueError(
            "The script requires CUDA support, but CUDA not available")

    args.rank = -1
    args.world_size = 1

    if args.model_parallel:
        args.deepspeed = False
        cfg = {
            "microbatches": args.num_microbatches,
            "placement_strategy": args.placement_strategy,
            "pipeline": args.pipeline,
            "optimize": args.optimize,
            "partitions": args.num_partitions,
            "horovod": args.horovod,
            "ddp": args.ddp,
        }

        smp.init(cfg)
        torch.cuda.set_device(smp.local_rank())
        args.rank = smp.dp_rank()
        args.world_size = smp.size()
    else:
        # initialize deepspeed
        print(f"args.deepspeed : {args.deepspeed}")
        deepspeed_utils.init_deepspeed(args.deepspeed)
        if deepspeed_utils.is_root_worker():
            args.rank = 0

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed + args.rank)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

    # args.LEARNING_RATE = args.LEARNING_RATE * float(args.world_size)

    cudnn.deterministic = True

    if cudnn.deterministic:
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.kwargs = {'num_workers': args.num_worker, 'pin_memory': True}

    device = torch.device("cuda")

    logger.debug(f"args.image_folder : {args.image_folder}")
    logger.debug(f"args.rank : {args.rank}")

    ## SageMaker
    try:
        if os.environ.get('SM_MODEL_DIR') is not None:
            args.model_dir = os.environ.get('SM_MODEL_DIR')
            #             args.output_dir = os.environ.get('SM_OUTPUT_DATA_DIR')
            args.image_folder = os.environ.get('SM_CHANNEL_TRAINING')
    except Exception:
        logger.debug("not SageMaker")

    IMAGE_SIZE = args.image_size
    IMAGE_PATH = args.image_folder

    EPOCHS = args.EPOCHS
    BATCH_SIZE = args.BATCH_SIZE
    LEARNING_RATE = args.LEARNING_RATE
    LR_DECAY_RATE = args.LR_DECAY_RATE

    NUM_TOKENS = args.NUM_TOKENS
    NUM_LAYERS = args.NUM_LAYERS
    NUM_RESNET_BLOCKS = args.NUM_RESNET_BLOCKS
    SMOOTH_L1_LOSS = args.SMOOTH_L1_LOSS
    EMB_DIM = args.EMB_DIM
    HID_DIM = args.HID_DIM
    KL_LOSS_WEIGHT = args.KL_LOSS_WEIGHT

    STARTING_TEMP = args.STARTING_TEMP
    TEMP_MIN = args.TEMP_MIN
    ANNEAL_RATE = args.ANNEAL_RATE

    NUM_IMAGES_SAVE = args.NUM_IMAGES_SAVE

    #     transform = Compose(
    #         [
    #             RandomResizedCrop(args.image_size, args.image_size),
    #             OneOf(
    #                 [
    #                     IAAAdditiveGaussianNoise(),
    #                     GaussNoise(),
    #                 ],
    #                 p=0.2
    #             ),
    #             VerticalFlip(p=0.5),
    #             OneOf(
    #                 [
    #                     MotionBlur(p=.2),
    #                     MedianBlur(blur_limit=3, p=0.1),
    #                     Blur(blur_limit=3, p=0.1),
    #                 ],
    #                 p=0.2
    #             ),
    #             OneOf(
    #                 [
    #                     CLAHE(clip_limit=2),
    #                     IAASharpen(),
    #                     IAAEmboss(),
    #                     RandomBrightnessContrast(),
    #                 ],
    #                 p=0.3
    #             ),
    #             HueSaturationValue(p=0.3),
    # #             Normalize(
    # #                 mean=[0.485, 0.456, 0.406],
    # #                 std=[0.229, 0.224, 0.225],
    # #             )
    #         ],
    #         p=1.0
    #     )

    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize(IMAGE_SIZE),
        T.CenterCrop(IMAGE_SIZE),
        T.ToTensor()
    ])

    sampler = None
    dl = None

    # data
    logger.debug(f"IMAGE_PATH : {IMAGE_PATH}")
    #     ds = AlbumentationImageDataset(
    #         IMAGE_PATH,
    #         transform=transform,
    #         args=args
    #     )
    ds = ImageFolder(
        IMAGE_PATH,
        transform=transform,
    )

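    # With SMP data parallelism, shard the dataset manually so each dp rank
    # trains on its own 1/dp_size partition of the images.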
    if args.model_parallel and (args.ddp
                                or args.horovod) and smp.dp_size() > 1:
        partitions_dict = {
            f"{i}": 1 / smp.dp_size()
            for i in range(smp.dp_size())
        }
        ds = SplitDataset(ds, partitions=partitions_dict)
        ds.select(f"{smp.dp_rank()}")

    dl = DataLoader(ds,
                    BATCH_SIZE,
                    shuffle=True,
                    drop_last=args.model_parallel,
                    **args.kwargs)

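    # Keep the DiscreteVAE hyperparameters in a dict so they can be saved
    # alongside the weights (see save_model below).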
    vae_params = dict(image_size=IMAGE_SIZE,
                      num_layers=NUM_LAYERS,
                      num_tokens=NUM_TOKENS,
                      codebook_dim=EMB_DIM,
                      hidden_dim=HID_DIM,
                      num_resnet_blocks=NUM_RESNET_BLOCKS)

    vae = DiscreteVAE(**vae_params,
                      smooth_l1_loss=SMOOTH_L1_LOSS,
                      kl_div_loss_weight=KL_LOSS_WEIGHT).to(device)
    # optimizer

    opt = Adam(vae.parameters(), lr=LEARNING_RATE)
    sched = ExponentialLR(optimizer=opt, gamma=LR_DECAY_RATE)

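    # Wrap model and optimizer for SageMaker model parallelism. Plain copies
    # of the codebook and decoder are kept so hard reconstructions can be
    # decoded outside the partitioned model (see hard_recons_step below).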
    if args.model_parallel:
        import copy
        dummy_codebook = copy.deepcopy(vae.codebook)
        dummy_decoder = copy.deepcopy(vae.decoder)

        vae = smp.DistributedModel(vae)
        scaler = smp.amp.GradScaler()
        opt = smp.DistributedOptimizer(opt)

        if args.partial_checkpoint:
            args.checkpoint = smp.load(args.partial_checkpoint, partial=True)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])
        elif args.full_checkpoint:
            args.checkpoint = smp.load(args.full_checkpoint, partial=False)
            vae.load_state_dict(args.checkpoint["model_state_dict"])
            opt.load_state_dict(args.checkpoint["optimizer_state_dict"])

    assert len(ds) > 0, 'folder does not contain any images'

    if (not args.model_parallel) and args.rank == 0:
        print(f'{len(ds)} images found for training')

        # weights & biases experiment tracking

        #         import wandb

        model_config = dict(num_tokens=NUM_TOKENS,
                            smooth_l1_loss=SMOOTH_L1_LOSS,
                            num_resnet_blocks=NUM_RESNET_BLOCKS,
                            kl_loss_weight=KL_LOSS_WEIGHT)

#         run = wandb.init(
#             project = 'dalle_train_vae',
#             job_type = 'train_model',
#             config = model_config
#         )

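    # Only the root rank writes checkpoints; the hyperparameters are bundled
    # with the weights so the VAE can be re-instantiated from the file alone.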
    def save_model(path):
        if not args.rank == 0:
            return

        save_obj = {'hparams': vae_params, 'weights': vae.state_dict()}

        torch.save(save_obj, path)

    # distribute with deepspeed
    if not args.model_parallel:
        deepspeed_utils.check_batch_size(BATCH_SIZE)
        deepspeed_config = {'train_batch_size': BATCH_SIZE}

        (distr_vae, opt, dl, sched) = deepspeed_utils.maybe_distribute(
            args=args,
            model=vae,
            optimizer=opt,
            model_parameters=vae.parameters(),
            training_data=ds if args.deepspeed else dl,
            lr_scheduler=sched,
            config_params=deepspeed_config,
        )

    try:
        # Rubik: Define smp.step. Return any tensors needed outside.
        @smp.step
        def train_step(vae, images, temp):
            #             logger.debug(f"args.amp : {args.amp}")
            with autocast(enabled=(args.amp > 0)):
                loss, recons = vae(images,
                                   return_loss=True,
                                   return_recons=True,
                                   temp=temp)

            scaled_loss = scaler.scale(loss) if args.amp else loss
            vae.backward(scaled_loss)
            #             torch.nn.utils.clip_grad_norm_(vae.parameters(), 5)
            return loss, recons

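        # Returns, for the first k images, the index of the most likely
        # codebook entry at every spatial position.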
        @smp.step
        def get_codes_step(vae, images, k):
            images = images[:k]
            logits = vae.forward(images, return_logits=True)
            codebook_indices = logits.argmax(dim=1).flatten(1)
            return codebook_indices

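        # SMP patches module.forward for pipelined execution; restore the
        # original forward methods on the dummy copies so they can decode the
        # gathered codebook indices eagerly on a single device.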
        def hard_recons_step(dummy_decoder, dummy_codebook, codebook_indices):
            from functools import partial
            for module in dummy_codebook.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            image_embeds = dummy_codebook.forward(codebook_indices)
            b, n, d = image_embeds.shape
            h = w = int(sqrt(n))

            image_embeds = rearrange(image_embeds,
                                     'b (h w) d -> b d h w',
                                     h=h,
                                     w=w)
            for module in dummy_decoder.modules():
                method = smp_state.patch_manager.get_original_method(
                    "forward", type(module))
                module.forward = partial(method, module)
            hard_recons = dummy_decoder.forward(image_embeds)
            return hard_recons

    except Exception:
        # The @smp.step definitions above require SageMaker model parallelism
        # to be initialized; skip them when it is not.
        pass

    # starting temperature

    global_step = 0
    temp = STARTING_TEMP

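    # Main training loop; the relaxation temperature passed to the VAE is
    # annealed from STARTING_TEMP towards TEMP_MIN as global_step grows.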
    for epoch in range(EPOCHS):
        ##
        batch_time = util.AverageMeter('Time', ':6.3f')
        data_time = util.AverageMeter('Data', ':6.3f')
        losses = util.AverageMeter('Loss', ':.4e')
        top1 = util.AverageMeter('Acc@1', ':6.2f')
        top5 = util.AverageMeter('Acc@5', ':6.2f')
        progress = util.ProgressMeter(
            len(dl), [batch_time, data_time, losses, top1, top5],
            prefix="Epoch: [{}]".format(epoch))

        vae.train()
        start = time.time()

        for i, (images, _) in enumerate(dl):
            images = images.to(device, non_blocking=True)
            opt.zero_grad()

            if args.model_parallel:
                loss, recons = train_step(vae, images, temp)
                # Rubik: Average the loss across microbatches.
                loss = loss.reduce_mean()
                recons = recons.reduce_mean()
            else:
                loss, recons = distr_vae(images,
                                         return_loss=True,
                                         return_recons=True,
                                         temp=temp)

            if (not args.model_parallel) and args.deepspeed:
                # Gradients are automatically zeroed after the step
                distr_vae.backward(loss)
                distr_vae.step()
            elif args.model_parallel:
                if args.amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    # some optimizers (e.g. Adadelta in PyTorch 1.8) don't like optimizer.step() being called when this rank holds no parameters
                    if len(list(vae.local_parameters())) > 0:
                        opt.step()
            else:
                loss.backward()
                opt.step()

            logs = {}

            if i % 10 == 0:
                if args.rank == 0:
                    #                 if deepspeed_utils.is_root_worker():
                    k = NUM_IMAGES_SAVE

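                    # Periodically decode reconstructions for logging; under model
                    # parallelism the decoder/codebook weights are copied into the
                    # dummy modules so decoding runs outside the pipeline.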
                    with torch.no_grad():
                        if args.model_parallel:
                            model_dict = vae.state_dict()
                            model_dict_updated = {}
                            for key, val in model_dict.items():
                                if "decoder" in key:
                                    key = key.replace("decoder.", "")
                                elif "codebook" in key:
                                    key = key.replace("codebook.", "")
                                model_dict_updated[key] = val

                            dummy_decoder.load_state_dict(model_dict_updated,
                                                          strict=False)
                            dummy_codebook.load_state_dict(model_dict_updated,
                                                           strict=False)
                            codes = get_codes_step(vae, images, k)
                            codes = codes.reduce_mean().to(torch.long)
                            hard_recons = hard_recons_step(
                                dummy_decoder, dummy_codebook, codes)
                        else:
                            codes = vae.get_codebook_indices(images[:k])
                            hard_recons = vae.decode(codes)

                    images, recons = map(lambda t: t[:k], (images, recons))
                    images, recons, hard_recons, codes = map(
                        lambda t: t.detach().cpu(),
                        (images, recons, hard_recons, codes))
                    images, recons, hard_recons = map(
                        lambda t: make_grid(t.float(),
                                            nrow=int(sqrt(k)),
                                            normalize=True,
                                            range=(-1, 1)),
                        (images, recons, hard_recons))

#                     logs = {
#                         **logs,
#                         'sample images':        wandb.Image(images, caption = 'original images'),
#                         'reconstructions':      wandb.Image(recons, caption = 'reconstructions'),
#                         'hard reconstructions': wandb.Image(hard_recons, caption = 'hard reconstructions'),
#                         'codebook_indices':     wandb.Histogram(codes),
#                         'temperature':          temp
#                     }

                if args.model_parallel:
                    filename = f'{args.model_dir}/vae.pt'
                    if smp.dp_rank() == 0:
                        if args.save_full_model:
                            model_dict = vae.state_dict()
                            opt_dict = opt.state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=False,
                            )
                        else:
                            model_dict = vae.local_state_dict()
                            opt_dict = opt.local_state_dict()
                            smp.save(
                                {
                                    "model_state_dict": model_dict,
                                    "optimizer_state_dict": opt_dict
                                },
                                filename,
                                partial=True,
                            )
                    smp.barrier()

                else:
                    save_model(f'{args.model_dir}/vae.pt')
                    # wandb.save(f'{args.model_dir}/vae.pt')

                # temperature anneal

                temp = max(temp * math.exp(-ANNEAL_RATE * global_step),
                           TEMP_MIN)

                # lr decay

                sched.step()

            # Loss for logging: the deepspeed path allreduces the mean across
            # workers, while the model-parallel path only rescales the local
            # loss by the world size.
            if args.model_parallel:
                avg_loss = loss.detach().clone()
                #                 print("args.world_size : {}".format(args.world_size))
                avg_loss /= args.world_size

            else:
                avg_loss = deepspeed_utils.average_all(loss)

            if args.rank == 0:
                if i % 100 == 0:
                    lr = sched.get_last_lr()[0]
                    print(epoch, i, f'lr - {lr:6f} loss - {avg_loss.item()}')

                    logs = {
                        **logs, 'epoch': epoch,
                        'iter': i,
                        'loss': avg_loss.item(),
                        'lr': lr
                    }

#                 wandb.log(logs)
            global_step += 1

            if args.rank == 0:
                # Every print_freq iterations, check the loss, accuracy, and speed.
                # For best performance, it doesn't make sense to print these metrics every
                # iteration, since they incur an allreduce and some host<->device syncs.

                # Measure accuracy
                #                 prec1, prec5 = util.accuracy(output, target, topk=(1, 5))

                # to_python_float incurs a host<->device sync
                losses.update(util.to_python_float(loss), images.size(0))
                #                 top1.update(util.to_python_float(prec1), images.size(0))
                #                 top5.update(util.to_python_float(prec5), images.size(0))

                # Wait for outstanding GPU work (PyTorch launches kernels asynchronously)
                torch.cuda.synchronize()
                batch_time.update((time.time() - start) / args.log_interval)
                start = time.time()  # reset the interval timer

                print(
                    'Epoch: [{0}][{1}/{2}] '
                    'Train_Time={batch_time.val:.3f}: avg-{batch_time.avg:.3f}, '
                    'Train_Speed={3:.3f} ({4:.3f}), '
                    'Train_Loss={loss.val:.10f}:({loss.avg:.4f}),'.format(
                        epoch,
                        i,
                        len(dl),
                        args.world_size * BATCH_SIZE / batch_time.val,
                        args.world_size * BATCH_SIZE / batch_time.avg,
                        batch_time=batch_time,
                        loss=losses))

#         if deepspeed_utils.is_root_worker():
# save trained model to wandb as an artifact every epoch's end

#             model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#             model_artifact.add_file(f'{args.model_dir}/vae.pt')
#             run.log_artifact(model_artifact)

    if args.rank == 0:
        #     if deepspeed_utils.is_root_worker():
        # save final vae and cleanup
        if args.model_parallel:
            logger.debug('save model_parallel')
        else:
            save_model(os.path.join(args.model_dir, 'vae-final.pt'))


#         wandb.save(f'{args.model_dir}/vae-final.pt')

#         model_artifact = wandb.Artifact('trained-vae', type = 'model', metadata = dict(model_config))
#         model_artifact.add_file(f'{args.model_dir}/vae-final.pt')
#         run.log_artifact(model_artifact)

#         wandb.finish()

    if args.model_parallel:
        if args.assert_losses:
            if args.horovod or args.ddp:
                # SM Distributed: If using data parallelism, gather all losses across different model
                # replicas and check if losses match.

                gathered_losses = smp.allgather(loss, smp.DP_GROUP)
                for l in gathered_losses:
                    print(l)
                    assert math.isclose(l, gathered_losses[0])

                assert loss < 0.18
            else:
                assert loss < 0.08

        smp.barrier()
        print("SMP training finished successfully")