def __init__(self, image_size, latent_dim = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = [], fq_dict_size = 256, attn_layers = []):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent)
        self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent)

        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent)

        # experimental contrastive loss discriminator regularization
        self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten') if cl_reg else None

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr = self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(), lr = self.lr, betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()
        
        if fp16:
            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O2')
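
A minimal usage sketch (not from the original) for the optional contrastive regularizer built above: when cl_reg is enabled, D_cl wraps the discriminator as a ContrastiveLearner, and calling it on a batch of images returns a scalar loss that can be backpropagated through the discriminator. The names gan, real_images and cl_weight below are hypothetical placeholders.

if gan.D_cl is not None:
    gan.D_opt.zero_grad()
    cl_loss = gan.D_cl(real_images)      # contrastive loss over two augmented views
    (cl_weight * cl_loss).backward()
    gan.D_opt.step()
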
Example #2
    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)
        self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent, fmap_max = fmap_max)

        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const)

        # experimental contrastive loss discriminator regularization
        assert not (transparent and cl_reg), 'contrastive loss regularization does not work with transparent images yet'
        self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten') if cl_reg else None

        # wrapper for augmenting all images going into the discriminator
        self.D_aug = AugWrapper(self.D, image_size)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(self.S.parameters())
        self.G_opt = AdamP(generator_params, lr = self.lr, betas=(0.5, 0.9))
        self.D_opt = AdamP(self.D.parameters(), lr = self.lr, betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()
        
        if fp16:
            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O2')
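
The examples above freeze the exponential-moving-average copies with set_requires_grad. For reference, a typical definition of this helper is sketched below; the repository's own version may differ slightly.

def set_requires_grad(model, value):
    # toggle gradient tracking for every parameter of the module
    for p in model.parameters():
        p.requires_grad = value
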
Example #3

    def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = [], fq_dict_size = 256, attn_layers = [], no_const = False, lr_mlp = 0.1, rank = 0):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp)
        self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)
        self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent, fmap_max = fmap_max)

        self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp)
        self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const)

        self.D_cl = None

        if cl_reg:
            from contrastive_learner import ContrastiveLearner
            # experimental contrastive loss discriminator regularization
            assert not transparent, 'contrastive loss regularization does not work with transparent images yet'
            self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten')

        # wrapper for augmenting all images going into the discriminator
        self.D_aug = AugWrapper(self.D, image_size)

        # turn off grad for exponential moving averages
        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        # init optimizers
        generator_params = list(self.G.parameters()) + list(self.S.parameters())
        self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
        self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))

        # init weights
        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda(rank)

        # set up apex mixed precision
        self.fp16 = fp16
        if fp16:
            (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1', num_losses=3)
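
Since amp.initialize is called here with opt_level='O1' and num_losses=3, each loss in the training step is expected to get its own scaled backward pass. A hedged, self-contained sketch of that apex pattern; the model, optimizer and loss below are placeholders, not from the original.

import torch
from apex import amp  # NVIDIA apex, as in the examples above

model = torch.nn.Linear(8, 1).cuda()                        # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', num_losses=3)

loss = model(torch.randn(4, 8).cuda()).mean()               # placeholder loss
with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
    scaled_loss.backward()                                  # scaled backward for loss id 0
optimizer.step()
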
Example #4

def Contrastive(num_classes, loss='triplet', pretrained=True, **kwargs):

    ConEncoder = torchreid.models.build_model(name='vit_timm_diet',
                                              num_classes=num_classes,
                                              loss=loss,
                                              pretrained=pretrained)

    #torchreid.utils.load_pretrained_weights(ConEncoder, '/home/danish/deep-person-reid/scripts/log/model/model.pth.tar-112')

    learner = ContrastiveLearner(
        ConEncoder,
        image_size=224,
        hidden_layer='head',
        # layer whose output is used as the hidden representation; can also be an integer index of the child module
        project_hidden=True,  # use projection head
        project_dim=128,  # projection head dimensions, 128 from paper
        use_nt_xent_loss=False,  # use the NT-Xent (SimCLR) loss; disabled here in favor of the default contrastive loss
        temperature=0.1,  # softmax temperature for the contrastive loss
        augment_both=True  # augment both query and key
    )
    return learner
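
A hedged usage sketch for the factory above; the optimizer, learning rate, batch shape and identity count are illustrative assumptions, not part of the original snippet.

import torch

learner = Contrastive(num_classes=751)            # hypothetical number of identities
optimizer = torch.optim.Adam(learner.parameters(), lr=3e-4)

images = torch.rand(16, 3, 224, 224)              # placeholder image batch
loss = learner(images)                            # contrastive loss between two augmented views
optimizer.zero_grad()
loss.backward()
optimizer.step()
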
Example #5
    ).to(device)
    g_ema.requires_grad_(False)
    g_ema.eval()
    accumulate(g_ema, generator, 0)

    augment_fn = nn.Sequential(
        nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)),  # zoom out
        augs.RandomHorizontalFlip(),
        RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2),
        RandomApply(augs.RandomRotation(180), p=0.2),
        augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)),
        RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1),  # zoom in
        RandomApply(augs.RandomErasing(), p=0.1),
    )
    contrast_learner = (
        ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0))
        if args.contrastive > 0
        else None
    )

    g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
    d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

    g_optim = th.optim.Adam(
        generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio),
    )
    d_optim = th.optim.Adam(
        discriminator.parameters(), lr=args.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio),
    )
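    # Aside (not part of the original script): with the common StyleGAN2
    # defaults g_reg_every = 4 and d_reg_every = 16, the lazy-regularization
    # scaling above works out to
    #   g_reg_ratio = 4 / 5  = 0.8   -> generator lr = 0.8 * args.lr,      betas ~ (0.0, 0.992)
    #   d_reg_ratio = 16 / 17 ~ 0.94 -> discriminator lr ~ 0.94 * args.lr, betas ~ (0.0, 0.991)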

    if args.checkpoint is not None:
Example #6

    def __init__(
        self,
        image_size,
        latent_dim=512,
        style_depth=8,
        network_capacity=16,
        transparent=False,
        fp16=False,
        cl_reg=False,
        augment_fn=None,
        steps=1,
        lr=1e-4,
        fq_layers=[],
        fq_dict_size=256,
        attn_layers=[],
    ):
        super().__init__()
        self.lr = lr
        self.steps = steps
        self.ema_updater = EMA(0.995)

        self.S = StyleVectorizer(latent_dim, style_depth)
        self.G = Generator(image_size,
                           latent_dim,
                           network_capacity,
                           transparent=transparent,
                           attn_layers=attn_layers)
        self.D = Discriminator(
            image_size,
            network_capacity,
            fq_layers=fq_layers,
            fq_dict_size=fq_dict_size,
            attn_layers=attn_layers,
            transparent=transparent,
        )

        self.SE = StyleVectorizer(latent_dim, style_depth)
        self.GE = Generator(image_size,
                            latent_dim,
                            network_capacity,
                            transparent=transparent,
                            attn_layers=attn_layers)

        set_requires_grad(self.SE, False)
        set_requires_grad(self.GE, False)

        generator_params = list(self.G.parameters()) + list(
            self.S.parameters())
        self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
        self.D_opt = DiffGrad(self.D.parameters(),
                              lr=self.lr,
                              betas=(0.5, 0.9))

        self._init_weights()
        self.reset_parameter_averaging()

        self.cuda()

        if fp16:
            (self.S, self.G, self.D, self.SE,
             self.GE), (self.G_opt, self.D_opt) = amp.initialize(
                 [self.S, self.G, self.D, self.SE, self.GE],
                 [self.G_opt, self.D_opt],
                 opt_level="O2")

        # experimental contrastive loss discriminator regularization
        if augment_fn is not None:
            self.augment_fn = augment_fn
        else:
            self.augment_fn = nn.Sequential(
                nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
                RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
                augs.RandomGrayscale(p=0.2),
                augs.RandomHorizontalFlip(),
                RandomApply(augs.RandomAffine(degrees=0,
                                              translate=(0.25, 0.25),
                                              shear=(15, 15)),
                            p=0.3),
                RandomApply(nn.Sequential(
                    augs.RandomRotation(180),
                    augs.CenterCrop(size=(image_size, image_size))),
                            p=0.2),
                augs.RandomResizedCrop(size=(image_size, image_size)),
                RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
                RandomApply(augs.RandomErasing(), p=0.1),
            )

        self.D_cl = (ContrastiveLearner(self.D,
                                        image_size,
                                        augment_fn=self.augment_fn,
                                        fp16=fp16,
                                        hidden_layer="flatten")
                     if cl_reg else None)
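
The default augment_fn above is an ordinary nn.Sequential of kornia transforms, so it applies directly to a batched image tensor. A minimal, self-contained sketch of that idea; the size and the two transforms chosen are illustrative assumptions.

import torch
import torch.nn as nn
import kornia.augmentation as augs

image_size = 128                                   # illustrative size
pipeline = nn.Sequential(
    augs.RandomHorizontalFlip(),
    augs.RandomResizedCrop(size=(image_size, image_size)),
)
batch = torch.rand(4, 3, image_size, image_size)   # (B, C, H, W) in [0, 1]
augmented = pipeline(batch)                        # same shape, randomly augmented
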
Example #7

            g_ema.eval()
            accumulate(g_ema, generator, 0)

            if args.contrastive > 0:
                contrast_learner = ContrastiveLearner(
                    discriminator,
                    args.size,
                    augment_fn=nn.Sequential(
                        nn.ReflectionPad2d(
                            int((math.sqrt(2) - 1) * args.size /
                                4)),  # zoom out
                        augs.RandomHorizontalFlip(),
                        RandomApply(augs.RandomAffine(degrees=0,
                                                      translate=(0.25, 0.25),
                                                      shear=(15, 15)),
                                    p=0.1),
                        RandomApply(augs.RandomRotation(180), p=0.1),
                        augs.RandomResizedCrop(size=(args.size, args.size),
                                               scale=(1, 1),
                                               ratio=(1, 1)),
                        RandomApply(augs.RandomResizedCrop(size=(args.size,
                                                                 args.size),
                                                           scale=(0.5, 0.9)),
                                    p=0.1),  # zoom in
                        RandomApply(augs.RandomErasing(), p=0.1),
                    ),
                    hidden_layer=(-1, 0),
                )
            else:
                contrast_learner = None

            g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
Example #9
class Trainer:
    def __init__(self, args=args):
        super().__init__()
        self.args = args
        # random_seed setting
        random_seed = args.randomseed
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if torch.cuda.device_count() > 1:
            torch.cuda.manual_seed_all(random_seed)
        else:
            torch.cuda.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        self.pretrain_stage = self.args.pretrainstage
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.slomofc = model.Slomofc(
            self.args.data_h, self.args.data_w, self.device, self.pretrain_stage
        )
        self.slomofc.to(self.device)
        if self.pretrain_stage:
            self.learner = ContrastiveLearner(
                self.slomofc,
                image_size=128,
                hidden_layer="avgpool",
                use_momentum=True,  # use momentum for key encoder
                momentum_value=0.999,
                project_hidden=False,  # no projection heads
                use_bilinear=True,  # in paper, logits is bilinear product of query / key
                use_nt_xent_loss=False,  # use regular contrastive loss
                augment_both=False,  # in curl, only the key is augmented
            )
        if self.args.init_type != "":
            init_net(self.slomofc, self.args.init_type)
            print(self.args.init_type + " initializing slomo done!")
        if self.args.train_continue:
            if not self.args.nocomet and self.args.cometid != "":
                self.comet_exp = ExistingExperiment(
                    previous_experiment=self.args.cometid
                )
            elif not self.args.nocomet and self.args.cometid == "":
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
            else:
                self.comet_exp = None
            self.ckpt_dict = torch.load(self.args.checkpoint)
            self.slomofc.load_state_dict(self.ckpt_dict["model_state_dict"])
            self.args.init_learning_rate = self.ckpt_dict["learningRate"]
            if not self.pretrain_stage:
                self.optimizer = optim.Adam(
                    self.slomofc.parameters(), lr=self.args.init_learning_rate
                )
            else:
                self.optimizer = optim.Adam(self.learner.parameters(), lr=3e-4)
            self.optimizer.load_state_dict(self.ckpt_dict["opt_state_dict"])
            print("Pretrained model loaded!")
        else:
            # start logging info in comet-ml
            if not self.args.nocomet:
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
                # self.comet_exp.log_parameters(flatten_opts(self.args))
            else:
                self.comet_exp = None
            if not self.pretrain_stage:
                self.ckpt_dict = {
                    "trainLoss": {},
                    "valLoss": {},
                    "valPSNR": {},
                    "valSSIM": {},
                    "learningRate": {},
                    "epoch": -1,
                    "detail": "End to end Super SloMo.",
                    "trainBatchSz": self.args.train_batch_size,
                    "validationBatchSz": self.args.validation_batch_size,
                }
            else:
                self.ckpt_dict = {
                    "conLoss": {},
                    "learningRate": {},
                    "epoch": -1,
                    "detail": "Pretrain_stage of Super SloMo.",
                    "trainBatchSz": self.args.train_batch_size,
                }
            if not self.pretrain_stage:
                self.optimizer = optim.Adam(
                    self.slomofc.parameters(), lr=self.args.init_learning_rate
                )
            else:
                self.optimizer = optim.Adam(self.learner.parameters(), lr=3e-4)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.args.milestones, gamma=0.1
        )
        # Channel wise mean calculated on adobe240-fps training dataset
        if not self.pretrain_stage:
            mean = [0.5, 0.5, 0.5]
            std = [1, 1, 1]
            self.normalize = transforms.Normalize(mean=mean, std=std)
            self.transform = transforms.Compose([transforms.ToTensor(), self.normalize])
        else:
            self.transform = transforms.Compose([transforms.ToTensor()])

        trainset = dataloader.SuperSloMo(
            root=self.args.dataset_root + "/train", transform=self.transform, train=True
        )
        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=self.args.train_batch_size,
            num_workers=self.args.num_workers,
            shuffle=True,
        )
        if not self.pretrain_stage:
            validationset = dataloader.SuperSloMo(
                root=self.args.dataset_root + "/validation",
                transform=self.transform,
                # randomCropSize=(128, 128),
                train=False,
            )
            self.validationloader = torch.utils.data.DataLoader(
                validationset,
                batch_size=self.args.validation_batch_size,
                num_workers=self.args.num_workers,
                shuffle=False,
            )
        ### loss
        if not self.pretrain_stage:
            self.supervisedloss = supervisedLoss()
            self.best = {
                "valLoss": 99999999,
                "valPSNR": -1,
                "valSSIM": -1,
            }
        else:
            self.best = {
                "conLoss": 99999999,
            }
        self.checkpoint_counter = int(
            (self.ckpt_dict["epoch"] + 1) / self.args.checkpoint_epoch
        )

    def train(self):
        for epoch in range(self.ckpt_dict["epoch"] + 1, self.args.epochs):
            print("Epoch: ", epoch)
            if not self.pretrain_stage:
                print("Training downstream task")
                print("Training epoch {}".format(epoch))
                _, _, train_loss = self.run_epoch(
                    epoch, self.trainloader, logimage=False, isTrain=True,
                )
                with torch.no_grad():
                    print("Validating epoch {}".format(epoch))
                    val_psnr, val_ssim, val_loss = self.run_epoch(
                        epoch, self.validationloader, logimage=True, isTrain=False,
                    )
                self.ckpt_dict["trainLoss"][str(epoch)] = train_loss
                self.ckpt_dict["valLoss"][str(epoch)] = val_loss
                self.ckpt_dict["valPSNR"][str(epoch)] = val_psnr
                self.ckpt_dict["valSSIM"][str(epoch)] = val_ssim
            else:
                print("Training pretrain task")
                conLoss = 0
                for trainIndex, data in enumerate(self.trainloader, 0):
                    loss = self.learner(data)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.learner.update_moving_average()  # update moving average of key encoder
                    conLoss += loss.item()
                conLoss = conLoss / len(self.trainloader)
                print(" Epoch: %4d  conLoss: %0.4f  " % (epoch, conLoss,))
                self.ckpt_dict["conLoss"][str(epoch)] = conLoss
                self.comet_exp.log_metric("conLoss", conLoss, epoch=epoch)

            self.ckpt_dict["learningRate"][str(epoch)] = get_lr(self.optimizer)
            self.ckpt_dict["epoch"] = epoch
            self.best = self.save_best(self.ckpt_dict, self.best, epoch, pretrain_stage=self.pretrain_stage)
            if (epoch % self.args.checkpoint_epoch) == self.args.checkpoint_epoch - 1:
                self.save()

    def save_best(self, current, best, epoch, pretrain_stage=True):
        save_best_done = False
        metrics = ["conLoss"] if pretrain_stage else ["valLoss", "valSSIM", "valPSNR"]
        for metric_name in metrics:
            if not save_best_done:
                if "Loss" in metric_name:
                    if best[metric_name] > current[metric_name][str(epoch)]:
                        best[metric_name] = current[metric_name][str(epoch)]
                        self.save(metric_name)
                        print(
                            "New Best "
                            + metric_name
                            + ": "
                            + str(best[metric_name])
                            + "saved"
                        )
                        save_best_done = True
                else:
                    if best[metric_name] < current[metric_name][str(epoch)]:
                        best[metric_name] = current[metric_name][str(epoch)]
                        self.save(metric_name)
                        print(
                            "New Best "
                            + metric_name
                            + ": "
                            + str(best[metric_name])
                            + "saved"
                        )
                        save_best_done = True
        return best

    @torch.no_grad()
    def save(self, save_metric_name=""):
        self.ckpt_dict["model_state_dict"] = self.slomofc.state_dict()
        self.ckpt_dict["opt_state_dict"] = self.optimizer.state_dict()
        file_name = (
            str(self.checkpoint_counter) if save_metric_name == "" else save_metric_name
        )
        model_name = (
            "/SuperSloMo" if not self.pretrain_stage else "/Pretrain_stage_SuperSloMo"
        )
        torch.save(
            self.ckpt_dict, self.args.checkpoint_dir + model_name + file_name + ".ckpt",
        )
        if save_metric_name == "":
            self.checkpoint_counter += 1

    ### Train and Valid
    def run_epoch(self, epoch, dataloader, logimage=False, isTrain=True):
        # For details see training.
        psnr_value = 0
        ssim_value = 0
        loss_value = 0
        if not isTrain:
            valid_images = []
        for index, all_data in enumerate(dataloader, 0):
            self.optimizer.zero_grad()
            (
                Ft_p,
                I0,
                IFrame,
                I1,
                g_I0_F_t_0,
                g_I1_F_t_1,
                FlowBackWarp_I0_F_1_0,
                FlowBackWarp_I1_F_0_1,
                F_1_0,
                F_0_1,
            ) = self.slomofc(all_data, pred_only=False, isTrain=isTrain)
            if (not isTrain) and logimage:
                if index % self.args.logimagefreq == 0:
                    for img in (I0, IFrame, I1, Ft_p):
                        valid_images.append(
                            255.0
                            * img.cpu()[0]
                            .resize_(1, 1, self.args.data_h, self.args.data_w)
                            .repeat(1, 3, 1, 1)
                        )
            # loss
            loss = self.supervisedloss(
                Ft_p,
                IFrame,
                I0,
                I1,
                g_I0_F_t_0,
                g_I1_F_t_1,
                FlowBackWarp_I0_F_1_0,
                FlowBackWarp_I1_F_0_1,
                F_1_0,
                F_0_1,
            )
            if isTrain:
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()

            loss_value += loss.item()

            # metrics
            psnr_value += psnr(Ft_p, IFrame, outputTensor=False)
            ssim_value += ssim(Ft_p, IFrame, outputTensor=False)

        name_loss = "TrainLoss" if isTrain else "ValLoss"
        itr = int(index + epoch * (len(dataloader)))
        if self.comet_exp is not None:
            self.comet_exp.log_metric(
                "PSNR", psnr_value / len(dataloader), step=itr, epoch=epoch
            )
            self.comet_exp.log_metric(
                "SSIM", ssim_value / len(dataloader), step=itr, epoch=epoch
            )
            self.comet_exp.log_metric(
                name_loss, loss_value / len(dataloader), step=itr, epoch=epoch
            )
            if logimage:
                upload_images(
                    valid_images,
                    epoch,
                    exp=self.comet_exp,
                    im_per_row=4,
                    rows_per_log=int(len(valid_images) / 4),
                )
        print(
            " Loss: %0.6f  Iterations: %4d/%4d  ValPSNR: %0.4f  ValSSIM: %0.4f "
            % (
                loss_value / len(dataloader),
                index,
                len(dataloader),
                psnr_value / len(dataloader),
                ssim_value / len(dataloader),
            )
        )
        return (
            (psnr_value / len(dataloader)),
            (ssim_value / len(dataloader)),
            (loss_value / len(dataloader)),
        )
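
A hedged sketch of how the Trainer above would typically be driven; the module-level args namespace (with fields such as randomseed, pretrainstage, dataset_root and epochs) is assumed to exist exactly as the class expects.

if __name__ == "__main__":
    trainer = Trainer(args)   # `args` is the argparse namespace referenced throughout
    trainer.train()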