def __init__(self, image_size, latent_dim = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = None, fq_dict_size = 256, attn_layers = None):
    """Build the style vectorizer, generator, discriminator, their frozen copies and optimizers.

    Args:
        image_size: Resolution forwarded to ``Generator``, ``Discriminator`` and,
            when ``cl_reg`` is set, ``ContrastiveLearner``.
        latent_dim: Latent size forwarded to ``StyleVectorizer`` and ``Generator``.
        style_depth: Depth forwarded to ``StyleVectorizer``.
        network_capacity: Capacity forwarded to ``Generator`` and ``Discriminator``.
        transparent: Forwarded to ``Generator`` and ``Discriminator``.
        fp16: When true, all five modules and both optimizers are wrapped with
            ``amp.initialize(..., opt_level='O2')``.
        cl_reg: When true, ``self.D_cl`` is a ``ContrastiveLearner`` around the
            discriminator (experimental regularizer); otherwise it is ``None``.
        steps: Stored on ``self.steps``.
        lr: Learning rate of both ``DiffGrad`` optimizers.
        fq_layers: Forwarded to ``Discriminator``; ``None`` means an empty list.
        fq_dict_size: Forwarded to ``Discriminator``.
        attn_layers: Forwarded to ``Discriminator``; ``None`` means an empty list.
    """
    super().__init__()
    # Fresh lists per call rather than one list object shared by every call.
    fq_layers = [] if fq_layers is None else fq_layers
    attn_layers = [] if attn_layers is None else attn_layers

    self.lr = lr
    self.steps = steps
    self.ema_updater = EMA(0.995)

    self.S = StyleVectorizer(latent_dim, style_depth)
    self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent)
    self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent)

    # Second vectorizer/generator pair, frozen below; presumably the EMA copies
    # maintained with self.ema_updater (TODO confirm against the update code).
    self.SE = StyleVectorizer(latent_dim, style_depth)
    self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent)

    # experimental contrastive loss discriminator regularization
    self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten') if cl_reg else None

    set_requires_grad(self.SE, False)
    set_requires_grad(self.GE, False)

    generator_params = list(self.G.parameters()) + list(self.S.parameters())
    self.G_opt = DiffGrad(generator_params, lr = self.lr, betas=(0.5, 0.9))
    self.D_opt = DiffGrad(self.D.parameters(), lr = self.lr, betas=(0.5, 0.9))

    self._init_weights()
    self.reset_parameter_averaging()

    self.cuda()

    if fp16:
        (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O2')
def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, fq_layers = None, fq_dict_size = 256, attn_layers = None, no_const = False):
    """Build the style vectorizer, generator, discriminator, frozen copies, augment wrapper and optimizers.

    Args:
        image_size: Resolution forwarded to the generators, discriminator,
            ``AugWrapper`` and (optionally) ``ContrastiveLearner``.
        latent_dim: Latent size forwarded to ``StyleVectorizer`` and ``Generator``.
        fmap_max: Feature-map cap forwarded to ``Generator`` (both ``G`` and ``GE``)
            and ``Discriminator``.
        style_depth: Depth forwarded to ``StyleVectorizer``.
        network_capacity: Capacity forwarded to ``Generator`` and ``Discriminator``.
        transparent: Forwarded to generators and discriminator; incompatible with ``cl_reg``.
        fp16: When true, modules and optimizers are wrapped with
            ``amp.initialize(..., opt_level='O2')``.
        cl_reg: When true, ``self.D_cl`` is a ``ContrastiveLearner`` around ``self.D``.
        steps: Stored on ``self.steps``.
        lr: Learning rate of both ``AdamP`` optimizers.
        fq_layers: Forwarded to ``Discriminator``; ``None`` means an empty list.
        fq_dict_size: Forwarded to ``Discriminator``.
        attn_layers: Forwarded to generators and discriminator; ``None`` means an empty list.
        no_const: Forwarded to ``Generator``.

    Raises:
        AssertionError: if both ``transparent`` and ``cl_reg`` are set.
    """
    super().__init__()
    # Fresh lists per call rather than one list object shared by every call.
    fq_layers = [] if fq_layers is None else fq_layers
    attn_layers = [] if attn_layers is None else attn_layers

    self.lr = lr
    self.steps = steps
    self.ema_updater = EMA(0.995)

    self.S = StyleVectorizer(latent_dim, style_depth)
    self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)
    self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent, fmap_max = fmap_max)

    self.SE = StyleVectorizer(latent_dim, style_depth)
    # GE must be built exactly like G (including fmap_max); otherwise its
    # architecture diverges from G whenever fmap_max differs from Generator's default.
    self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)

    # experimental contrastive loss discriminator regularization
    assert not (transparent and cl_reg), 'contrastive loss regularization does not work with transparent images yet'
    self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten') if cl_reg else None

    # wrapper for augmenting all images going into the discriminator
    self.D_aug = AugWrapper(self.D, image_size)

    set_requires_grad(self.SE, False)
    set_requires_grad(self.GE, False)

    generator_params = list(self.G.parameters()) + list(self.S.parameters())
    self.G_opt = AdamP(generator_params, lr = self.lr, betas=(0.5, 0.9))
    self.D_opt = AdamP(self.D.parameters(), lr = self.lr, betas=(0.5, 0.9))

    self._init_weights()
    self.reset_parameter_averaging()

    self.cuda()

    if fp16:
        (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O2')
def __init__(self, image_size, latent_dim = 512, fmap_max = 512, style_depth = 8, network_capacity = 16, transparent = False, fp16 = False, cl_reg = False, steps = 1, lr = 1e-4, ttur_mult = 2, fq_layers = None, fq_dict_size = 256, attn_layers = None, no_const = False, lr_mlp = 0.1, rank = 0):
    """Build the networks, their EMA copies, augment wrapper and TTUR optimizers on device ``rank``.

    Args:
        image_size: Resolution forwarded to the generators, discriminator,
            ``AugWrapper`` and (optionally) ``ContrastiveLearner``.
        latent_dim: Latent size forwarded to ``StyleVectorizer`` and ``Generator``.
        fmap_max: Feature-map cap forwarded to ``Generator`` (both ``G`` and ``GE``)
            and ``Discriminator``.
        style_depth: Depth forwarded to ``StyleVectorizer``.
        network_capacity: Capacity forwarded to ``Generator`` and ``Discriminator``.
        transparent: Forwarded to generators/discriminator; incompatible with ``cl_reg``.
        fp16: Stored on ``self.fp16``; when true, modules and optimizers are
            wrapped with ``amp.initialize(..., opt_level='O1', num_losses=3)``.
        cl_reg: When true, ``self.D_cl`` is a ``ContrastiveLearner`` around ``self.D``
            (imported lazily); otherwise ``None``.
        steps: Stored on ``self.steps``.
        lr: Generator learning rate.
        ttur_mult: Discriminator learning rate is ``lr * ttur_mult``.
        fq_layers: Forwarded to ``Discriminator``; ``None`` means an empty list.
        fq_dict_size: Forwarded to ``Discriminator``.
        attn_layers: Forwarded to generators and discriminator; ``None`` means an empty list.
        no_const: Forwarded to ``Generator``.
        lr_mlp: Forwarded to ``StyleVectorizer`` as ``lr_mul``.
        rank: Device index passed to ``self.cuda``.

    Raises:
        AssertionError: if ``cl_reg`` is set together with ``transparent``.
    """
    super().__init__()
    # Fresh lists per call rather than one list object shared by every call.
    fq_layers = [] if fq_layers is None else fq_layers
    attn_layers = [] if attn_layers is None else attn_layers

    self.lr = lr
    self.steps = steps
    self.ema_updater = EMA(0.995)

    self.S = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp)
    self.G = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)
    self.D = Discriminator(image_size, network_capacity, fq_layers = fq_layers, fq_dict_size = fq_dict_size, attn_layers = attn_layers, transparent = transparent, fmap_max = fmap_max)

    self.SE = StyleVectorizer(latent_dim, style_depth, lr_mul = lr_mlp)
    # GE is the moving-average copy of G, so it must be built with identical
    # arguments (including fmap_max) or its architecture can diverge from G's.
    self.GE = Generator(image_size, latent_dim, network_capacity, transparent = transparent, attn_layers = attn_layers, no_const = no_const, fmap_max = fmap_max)

    self.D_cl = None

    if cl_reg:
        from contrastive_learner import ContrastiveLearner
        # experimental contrastive loss discriminator regularization
        assert not transparent, 'contrastive loss regularization does not work with transparent images yet'
        self.D_cl = ContrastiveLearner(self.D, image_size, hidden_layer='flatten')

    # wrapper for augmenting all images going into the discriminator
    self.D_aug = AugWrapper(self.D, image_size)

    # turn off grad for exponential moving averages
    set_requires_grad(self.SE, False)
    set_requires_grad(self.GE, False)

    # init optimizers; the discriminator runs at ttur_mult times the generator lr
    generator_params = list(self.G.parameters()) + list(self.S.parameters())
    self.G_opt = Adam(generator_params, lr = self.lr, betas=(0.5, 0.9))
    self.D_opt = Adam(self.D.parameters(), lr = self.lr * ttur_mult, betas=(0.5, 0.9))

    # init weights
    self._init_weights()
    self.reset_parameter_averaging()

    self.cuda(rank)

    # startup apex mixed precision
    self.fp16 = fp16
    if fp16:
        (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize([self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level='O1', num_losses=3)
def Contrastive(num_classes, loss='triplet', pretrained=True, **kwargs):
    """Wrap a torchreid ``vit_timm_diet`` encoder in a ``ContrastiveLearner``.

    Args:
        num_classes: Number of classes for ``torchreid.models.build_model``.
        loss: Loss name forwarded to ``build_model``.
        pretrained: Forwarded to ``build_model``.
        **kwargs: Accepted for call compatibility; currently ignored.

    Returns:
        The configured ``ContrastiveLearner`` instance.
    """
    encoder = torchreid.models.build_model(
        name='vit_timm_diet',
        num_classes=num_classes,
        loss=loss,
        pretrained=pretrained,
    )
    # NOTE: a local checkpoint can be loaded into `encoder` here with
    # torchreid.utils.load_pretrained_weights(encoder, <path>) before wrapping.

    learner_config = {
        'image_size': 224,
        # name of the child layer whose output is the hidden representation
        # (an integer child index is also accepted)
        'hidden_layer': 'head',
        'project_hidden': True,     # add a projection head
        'project_dim': 128,         # projection head size (128 as in the paper)
        'use_nt_xent_loss': False,  # NT-Xent loss disabled
        'temperature': 0.1,
        'augment_both': True,       # augment both query and key
    }
    return ContrastiveLearner(encoder, **learner_config)
).to(device) g_ema.requires_grad_(False) g_ema.eval() accumulate(g_ema, generator, 0) augment_fn = nn.Sequential( nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)), # zoom out augs.RandomHorizontalFlip(), RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.2), RandomApply(augs.RandomRotation(180), p=0.2), augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)), RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1), # zoom in RandomApply(augs.RandomErasing(), p=0.1), ) contrast_learner = ( ContrastiveLearner(discriminator, args.size, augment_fn=augment_fn, hidden_layer=(-1, 0)) if args.contrastive > 0 else None ) g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1) d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1) g_optim = th.optim.Adam( generator.parameters(), lr=args.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio), ) d_optim = th.optim.Adam( discriminator.parameters(), lr=args.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), ) if args.checkpoint is not None:
def __init__(
    self,
    image_size,
    latent_dim=512,
    style_depth=8,
    network_capacity=16,
    transparent=False,
    fp16=False,
    cl_reg=False,
    augment_fn=None,
    steps=1,
    lr=1e-4,
    fq_layers=None,
    fq_dict_size=256,
    attn_layers=None,
):
    """Build the networks, frozen copies, optimizers and an optional contrastive regularizer.

    Args:
        image_size: Resolution forwarded to the networks, the default
            augmentation pipeline and ``ContrastiveLearner``.
        latent_dim: Latent size forwarded to ``StyleVectorizer`` and ``Generator``.
        style_depth: Depth forwarded to ``StyleVectorizer``.
        network_capacity: Capacity forwarded to ``Generator`` and ``Discriminator``.
        transparent: Forwarded to generators and discriminator.
        fp16: When true, modules and optimizers are wrapped with
            ``amp.initialize(..., opt_level="O2")``; also forwarded to ``ContrastiveLearner``.
        cl_reg: When true, ``self.D_cl`` is a ``ContrastiveLearner`` around ``self.D``.
        augment_fn: Augmentation module stored on ``self.augment_fn``; when
            ``None`` a default kornia-style ``nn.Sequential`` pipeline is built.
        steps: Stored on ``self.steps``.
        lr: Learning rate of both ``DiffGrad`` optimizers.
        fq_layers: Forwarded to ``Discriminator``; ``None`` means an empty list.
        fq_dict_size: Forwarded to ``Discriminator``.
        attn_layers: Forwarded to generators and discriminator; ``None`` means an empty list.
    """
    super().__init__()
    # Fresh lists per call rather than one list object shared by every call.
    fq_layers = [] if fq_layers is None else fq_layers
    attn_layers = [] if attn_layers is None else attn_layers

    self.lr = lr
    self.steps = steps
    self.ema_updater = EMA(0.995)

    self.S = StyleVectorizer(latent_dim, style_depth)
    self.G = Generator(image_size, latent_dim, network_capacity, transparent=transparent, attn_layers=attn_layers)
    self.D = Discriminator(
        image_size,
        network_capacity,
        fq_layers=fq_layers,
        fq_dict_size=fq_dict_size,
        attn_layers=attn_layers,
        transparent=transparent,
    )

    self.SE = StyleVectorizer(latent_dim, style_depth)
    self.GE = Generator(image_size, latent_dim, network_capacity, transparent=transparent, attn_layers=attn_layers)

    set_requires_grad(self.SE, False)
    set_requires_grad(self.GE, False)

    generator_params = list(self.G.parameters()) + list(self.S.parameters())
    self.G_opt = DiffGrad(generator_params, lr=self.lr, betas=(0.5, 0.9))
    self.D_opt = DiffGrad(self.D.parameters(), lr=self.lr, betas=(0.5, 0.9))

    self._init_weights()
    self.reset_parameter_averaging()

    self.cuda()

    if fp16:
        (self.S, self.G, self.D, self.SE, self.GE), (self.G_opt, self.D_opt) = amp.initialize(
            [self.S, self.G, self.D, self.SE, self.GE], [self.G_opt, self.D_opt], opt_level="O2")

    # experimental contrastive loss discriminator regularization
    if augment_fn is not None:
        self.augment_fn = augment_fn
    else:
        self.augment_fn = nn.Sequential(
            # reflection-pad every side by (sqrt(2) - 1) * image_size / 4 ("zoom out")
            nn.ReflectionPad2d(int((sqrt(2) - 1) * image_size / 4)),
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.7),
            augs.RandomGrayscale(p=0.2),
            augs.RandomHorizontalFlip(),
            RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.3),
            RandomApply(nn.Sequential(
                augs.RandomRotation(180),
                augs.CenterCrop(size=(image_size, image_size))), p=0.2),
            augs.RandomResizedCrop(size=(image_size, image_size)),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            RandomApply(augs.RandomErasing(), p=0.1),
        )

    self.D_cl = (ContrastiveLearner(self.D, image_size, augment_fn=self.augment_fn, fp16=fp16, hidden_layer="flatten") if cl_reg else None)
g_ema.eval()
accumulate(g_ema, generator, 0)

# Optional contrastive regularizer on the discriminator; None unless args.contrastive > 0.
contrast_learner = (
    ContrastiveLearner(
        discriminator,
        args.size,
        augment_fn=nn.Sequential(
            # zoom out: reflection-pad every side by (sqrt(2) - 1) * size / 4
            nn.ReflectionPad2d(int((math.sqrt(2) - 1) * args.size / 4)),
            augs.RandomHorizontalFlip(),
            RandomApply(augs.RandomAffine(degrees=0, translate=(0.25, 0.25), shear=(15, 15)), p=0.1),
            RandomApply(augs.RandomRotation(180), p=0.1),
            augs.RandomResizedCrop(size=(args.size, args.size), scale=(1, 1), ratio=(1, 1)),
            # zoom in
            RandomApply(augs.RandomResizedCrop(size=(args.size, args.size), scale=(0.5, 0.9)), p=0.1),
            RandomApply(augs.RandomErasing(), p=0.1),
        ),
        hidden_layer=(-1, 0),
    )
    if args.contrastive > 0
    else None
)

g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
def __init__(self, args=args):
    """Set up seeding, the Slomofc model, logging, optimizer, scheduler, data and loss.

    When ``args.pretrainstage`` is true, a momentum ``ContrastiveLearner`` wraps the
    model and is what the optimizer trains; otherwise the model is optimized
    directly with a supervised loss. When ``args.train_continue`` is true, model
    and optimizer state are restored from ``args.checkpoint``; otherwise a fresh
    checkpoint dictionary is created.

    Args:
        args: Parsed options object; defaults to the module-level ``args``.
    """
    super().__init__()
    self.args = args

    # random_seed setting
    random_seed = args.randomseed
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    if torch.cuda.device_count() > 1:
        torch.cuda.manual_seed_all(random_seed)
    else:
        torch.cuda.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    self.pretrain_stage = self.args.pretrainstage
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.slomofc = model.Slomofc(
        self.args.data_h, self.args.data_w, self.device, self.pretrain_stage
    )
    self.slomofc.to(self.device)
    if self.pretrain_stage:
        self.learner = ContrastiveLearner(
            self.slomofc,
            image_size=128,
            hidden_layer="avgpool",
            use_momentum=True,  # use momentum for key encoder
            momentum_value=0.999,
            project_hidden=False,  # no projection heads
            use_bilinear=True,  # in paper, logits is bilinear product of query / key
            use_nt_xent_loss=False,  # use regular contrastive loss
            augment_both=False,  # in curl, only the key is augmented
        )
    if self.args.init_type != "":
        init_net(self.slomofc, self.args.init_type)
        print(self.args.init_type + " initializing slomo done!")

    if self.args.train_continue:
        if not self.args.nocomet and self.args.cometid != "":
            self.comet_exp = ExistingExperiment(
                previous_experiment=self.args.cometid
            )
        elif not self.args.nocomet and self.args.cometid == "":
            self.comet_exp = Experiment(
                workspace=self.args.workspace, project_name=self.args.projectname
            )
        else:
            self.comet_exp = None
        self.ckpt_dict = torch.load(self.args.checkpoint)
        self.slomofc.load_state_dict(self.ckpt_dict["model_state_dict"])
        # "learningRate" is a per-epoch history {str(epoch): lr} (a fresh checkpoint
        # starts it as an empty dict), not a number. Passing the dict to Adam
        # fails, so resume from the latest recorded value instead.
        lr_history = self.ckpt_dict["learningRate"]
        if isinstance(lr_history, dict):
            if lr_history:
                last_key = str(self.ckpt_dict["epoch"])
                if last_key not in lr_history:
                    last_key = max(lr_history, key=int)
                self.args.init_learning_rate = lr_history[last_key]
        else:
            self.args.init_learning_rate = lr_history
        if not self.pretrain_stage:
            self.optimizer = optim.Adam(
                self.slomofc.parameters(), lr=self.args.init_learning_rate
            )
        else:
            self.optimizer = optim.Adam(self.learner.parameters(), lr=3e-4)
        # Restores the saved param-group hyperparameters (including lr) as well.
        self.optimizer.load_state_dict(self.ckpt_dict["opt_state_dict"])
        print("Pretrained model loaded!")
    else:
        # start logging info in comet-ml
        if not self.args.nocomet:
            self.comet_exp = Experiment(
                workspace=self.args.workspace, project_name=self.args.projectname
            )
        else:
            self.comet_exp = None
        if not self.pretrain_stage:
            self.ckpt_dict = {
                "trainLoss": {},
                "valLoss": {},
                "valPSNR": {},
                "valSSIM": {},
                "learningRate": {},
                "epoch": -1,
                "detail": "End to end Super SloMo.",
                "trainBatchSz": self.args.train_batch_size,
                "validationBatchSz": self.args.validation_batch_size,
            }
        else:
            self.ckpt_dict = {
                "conLoss": {},
                "learningRate": {},
                "epoch": -1,
                "detail": "Pretrain_stage of Super SloMo.",
                "trainBatchSz": self.args.train_batch_size,
            }
        if not self.pretrain_stage:
            self.optimizer = optim.Adam(
                self.slomofc.parameters(), lr=self.args.init_learning_rate
            )
        else:
            self.optimizer = optim.Adam(self.learner.parameters(), lr=3e-4)

    # Built for both fresh and resumed runs. NOTE(review): scheduler state is
    # not saved in the checkpoint, so milestones restart counting on resume.
    self.scheduler = optim.lr_scheduler.MultiStepLR(
        self.optimizer, milestones=self.args.milestones, gamma=0.1
    )

    # Downstream inputs: ToTensor output shifted by a fixed 0.5 per channel (std 1).
    # Pretraining uses the plain ToTensor output.
    if not self.pretrain_stage:
        mean = [0.5, 0.5, 0.5]
        std = [1, 1, 1]
        self.normalize = transforms.Normalize(mean=mean, std=std)
        self.transform = transforms.Compose([transforms.ToTensor(), self.normalize])
    else:
        self.transform = transforms.Compose([transforms.ToTensor()])

    trainset = dataloader.SuperSloMo(
        root=self.args.dataset_root + "/train", transform=self.transform, train=True
    )
    self.trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=self.args.train_batch_size,
        num_workers=self.args.num_workers,
        shuffle=True,
    )

    if not self.pretrain_stage:
        validationset = dataloader.SuperSloMo(
            root=self.args.dataset_root + "/validation",
            transform=self.transform,
            train=False,
        )
        self.validationloader = torch.utils.data.DataLoader(
            validationset,
            batch_size=self.args.validation_batch_size,
            num_workers=self.args.num_workers,
            shuffle=False,
        )

    ### loss
    if not self.pretrain_stage:
        self.supervisedloss = supervisedLoss()
        self.best = {
            "valLoss": 99999999,
            "valPSNR": -1,
            "valSSIM": -1,
        }
    else:
        self.best = {
            "conLoss": 99999999,
        }
    self.checkpoint_counter = int(
        (self.ckpt_dict["epoch"] + 1) / self.args.checkpoint_epoch
    )
class Trainer:
    """Train Super SloMo (``model.Slomofc``) either supervised or as contrastive pretraining.

    ``args.pretrainstage`` selects the mode:

    * downstream (False): the model is trained with ``supervisedLoss`` and
      validated with PSNR/SSIM each epoch;
    * pretrain (True): a momentum ``ContrastiveLearner`` wraps the model and
      its loss is minimized.

    Progress is kept in ``self.ckpt_dict`` and written by ``save``.
    """

    def __init__(self, args=args):
        """Set up seeding, model, logging, optimizer, scheduler, data and loss.

        When ``args.train_continue`` is true, model and optimizer state are
        restored from ``args.checkpoint``; otherwise a fresh checkpoint
        dictionary is created.

        Args:
            args: Parsed options object; defaults to the module-level ``args``.
        """
        super().__init__()
        self.args = args

        # random_seed setting
        random_seed = args.randomseed
        np.random.seed(random_seed)
        torch.manual_seed(random_seed)
        if torch.cuda.device_count() > 1:
            torch.cuda.manual_seed_all(random_seed)
        else:
            torch.cuda.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        self.pretrain_stage = self.args.pretrainstage
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.slomofc = model.Slomofc(
            self.args.data_h, self.args.data_w, self.device, self.pretrain_stage
        )
        self.slomofc.to(self.device)
        if self.pretrain_stage:
            self.learner = ContrastiveLearner(
                self.slomofc,
                image_size=128,
                hidden_layer="avgpool",
                use_momentum=True,  # use momentum for key encoder
                momentum_value=0.999,
                project_hidden=False,  # no projection heads
                use_bilinear=True,  # in paper, logits is bilinear product of query / key
                use_nt_xent_loss=False,  # use regular contrastive loss
                augment_both=False,  # in curl, only the key is augmented
            )
        if self.args.init_type != "":
            init_net(self.slomofc, self.args.init_type)
            print(self.args.init_type + " initializing slomo done!")

        if self.args.train_continue:
            if not self.args.nocomet and self.args.cometid != "":
                self.comet_exp = ExistingExperiment(
                    previous_experiment=self.args.cometid
                )
            elif not self.args.nocomet and self.args.cometid == "":
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
            else:
                self.comet_exp = None
            self.ckpt_dict = torch.load(self.args.checkpoint)
            self.slomofc.load_state_dict(self.ckpt_dict["model_state_dict"])
            # "learningRate" is the per-epoch history written by train(), not a
            # number; Adam needs a scalar, so resume from the latest entry.
            self.args.init_learning_rate = self._last_learning_rate(
                self.ckpt_dict, self.args.init_learning_rate
            )
            self.optimizer = self._build_optimizer()
            # Restores the saved param-group hyperparameters (including lr) as well.
            self.optimizer.load_state_dict(self.ckpt_dict["opt_state_dict"])
            print("Pretrained model loaded!")
        else:
            # start logging info in comet-ml
            if not self.args.nocomet:
                self.comet_exp = Experiment(
                    workspace=self.args.workspace, project_name=self.args.projectname
                )
            else:
                self.comet_exp = None
            if not self.pretrain_stage:
                self.ckpt_dict = {
                    "trainLoss": {},
                    "valLoss": {},
                    "valPSNR": {},
                    "valSSIM": {},
                    "learningRate": {},
                    "epoch": -1,
                    "detail": "End to end Super SloMo.",
                    "trainBatchSz": self.args.train_batch_size,
                    "validationBatchSz": self.args.validation_batch_size,
                }
            else:
                self.ckpt_dict = {
                    "conLoss": {},
                    "learningRate": {},
                    "epoch": -1,
                    "detail": "Pretrain_stage of Super SloMo.",
                    "trainBatchSz": self.args.train_batch_size,
                }
            self.optimizer = self._build_optimizer()

        # Built for both fresh and resumed runs. NOTE(review): scheduler state is
        # not checkpointed, so milestones restart counting on resume.
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=self.args.milestones, gamma=0.1
        )

        # Downstream inputs: ToTensor output shifted by a fixed 0.5 per channel
        # (std 1). Pretraining uses the plain ToTensor output.
        if not self.pretrain_stage:
            mean = [0.5, 0.5, 0.5]
            std = [1, 1, 1]
            self.normalize = transforms.Normalize(mean=mean, std=std)
            self.transform = transforms.Compose([transforms.ToTensor(), self.normalize])
        else:
            self.transform = transforms.Compose([transforms.ToTensor()])

        trainset = dataloader.SuperSloMo(
            root=self.args.dataset_root + "/train", transform=self.transform, train=True
        )
        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=self.args.train_batch_size,
            num_workers=self.args.num_workers,
            shuffle=True,
        )

        if not self.pretrain_stage:
            validationset = dataloader.SuperSloMo(
                root=self.args.dataset_root + "/validation",
                transform=self.transform,
                train=False,
            )
            self.validationloader = torch.utils.data.DataLoader(
                validationset,
                batch_size=self.args.validation_batch_size,
                num_workers=self.args.num_workers,
                shuffle=False,
            )

        ### loss
        if not self.pretrain_stage:
            self.supervisedloss = supervisedLoss()
            self.best = {
                "valLoss": 99999999,
                "valPSNR": -1,
                "valSSIM": -1,
            }
        else:
            self.best = {
                "conLoss": 99999999,
            }
        self.checkpoint_counter = int(
            (self.ckpt_dict["epoch"] + 1) / self.args.checkpoint_epoch
        )

    def _build_optimizer(self):
        """Create Adam over the model (downstream) or the contrastive learner (pretrain)."""
        if not self.pretrain_stage:
            return optim.Adam(self.slomofc.parameters(), lr=self.args.init_learning_rate)
        return optim.Adam(self.learner.parameters(), lr=3e-4)

    @staticmethod
    def _last_learning_rate(ckpt_dict, fallback):
        """Return the latest learning rate stored in a checkpoint dictionary.

        ``ckpt_dict["learningRate"]`` is a ``{str(epoch): lr}`` history. The
        entry for ``ckpt_dict["epoch"]`` is preferred, then the highest epoch
        key. A non-dict value is returned unchanged, and *fallback* is used
        when the history is missing or empty.
        """
        history = ckpt_dict.get("learningRate", fallback)
        if not isinstance(history, dict):
            return history
        if not history:
            return fallback
        key = str(ckpt_dict.get("epoch"))
        if key not in history:
            key = max(history, key=int)
        return history[key]

    def train(self):
        """Train from the epoch after ``ckpt_dict["epoch"]`` up to ``args.epochs``.

        After every epoch, the learning rate and epoch are recorded, the best
        metrics are updated (possibly saving a per-metric checkpoint), and a
        numbered checkpoint is written every ``args.checkpoint_epoch`` epochs.
        """
        for epoch in range(self.ckpt_dict["epoch"] + 1, self.args.epochs):
            print("Epoch: ", epoch)
            epoch_key = str(epoch)
            if not self.pretrain_stage:
                print("Training downstream task")
                print("Training epoch {}".format(epoch))
                _, _, train_loss = self.run_epoch(
                    epoch, self.trainloader, logimage=False, isTrain=True,
                )
                with torch.no_grad():
                    print("Validating epoch {}".format(epoch))
                    val_psnr, val_ssim, val_loss = self.run_epoch(
                        epoch, self.validationloader, logimage=True, isTrain=False,
                    )
                self.ckpt_dict["trainLoss"][epoch_key] = train_loss
                self.ckpt_dict["valLoss"][epoch_key] = val_loss
                self.ckpt_dict["valPSNR"][epoch_key] = val_psnr
                self.ckpt_dict["valSSIM"][epoch_key] = val_ssim
            else:
                print("Training pretrain task")
                conLoss = 0
                for data in self.trainloader:
                    loss = self.learner(data)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    self.learner.update_moving_average()  # update moving average of key encoder
                    conLoss += loss.item()
                conLoss = conLoss / len(self.trainloader)
                print(" Epoch: %4d conLoss: %0.4f " % (epoch, conLoss))
                self.ckpt_dict["conLoss"][epoch_key] = conLoss
                # comet_exp is None when args.nocomet is set.
                if self.comet_exp is not None:
                    self.comet_exp.log_metric("conLoss", conLoss, epoch=epoch)

            self.ckpt_dict["learningRate"][epoch_key] = get_lr(self.optimizer)
            self.ckpt_dict["epoch"] = epoch
            # Pass the stage explicitly: downstream runs have no "conLoss" history.
            self.best = self.save_best(
                self.ckpt_dict, self.best, epoch, pretrain_stage=self.pretrain_stage
            )
            if (epoch % self.args.checkpoint_epoch) == self.args.checkpoint_epoch - 1:
                self.save()

    def save_best(self, current, best, epoch, pretrain_stage=None):
        """Update *best* with this epoch's metrics and save a checkpoint on improvement.

        Metrics are checked in order (``conLoss`` for pretraining; ``valLoss``,
        ``valSSIM``, ``valPSNR`` otherwise). Losses improve downward, the others
        upward. Only the first improved metric is recorded and saved per epoch.

        Args:
            current: Checkpoint dict holding ``{metric: {str(epoch): value}}``.
            best: Dict of best values so far, keyed by metric name.
            epoch: Epoch whose values are compared.
            pretrain_stage: Which metric set to use; ``None`` means ``self.pretrain_stage``.

        Returns:
            The updated *best* dict (also mutated in place).
        """
        if pretrain_stage is None:
            pretrain_stage = self.pretrain_stage
        metrics = ["conLoss"] if pretrain_stage else ["valLoss", "valSSIM", "valPSNR"]
        epoch_key = str(epoch)
        for metric_name in metrics:
            value = current[metric_name][epoch_key]
            if "Loss" in metric_name:
                improved = best[metric_name] > value
            else:
                improved = best[metric_name] < value
            if improved:
                best[metric_name] = value
                self.save(metric_name)
                print("New Best " + metric_name + ": " + str(best[metric_name]) + " saved")
                break
        return best

    @torch.no_grad()
    def save(self, save_metric_name=""):
        """Write ``ckpt_dict`` plus model and optimizer state under ``args.checkpoint_dir``.

        Without a metric name, the file is numbered by ``checkpoint_counter``,
        which is then incremented. Otherwise the metric name is used as the
        suffix, e.g. ``<dir>/SuperSloMovalLoss.ckpt``.

        NOTE(review): in the pretrain stage only ``slomofc`` weights are saved,
        not the learner's extra modules (e.g. its key encoder).
        """
        self.ckpt_dict["model_state_dict"] = self.slomofc.state_dict()
        self.ckpt_dict["opt_state_dict"] = self.optimizer.state_dict()
        file_name = (
            str(self.checkpoint_counter) if save_metric_name == "" else save_metric_name
        )
        model_name = (
            "/SuperSloMo" if not self.pretrain_stage else "/Pretrain_stage_SuperSloMo"
        )
        torch.save(
            self.ckpt_dict,
            self.args.checkpoint_dir + model_name + file_name + ".ckpt",
        )
        if save_metric_name == "":
            self.checkpoint_counter += 1

    ### Train and Valid
    def run_epoch(self, epoch, dataloader, logimage=False, isTrain=True):
        """Run one supervised pass over *dataloader*, training when ``isTrain``.

        When validating with ``logimage``, every ``args.logimagefreq``-th batch
        contributes four frames (I0, IFrame, I1, prediction), scaled by 255 and
        uploaded to comet.

        Returns:
            ``(psnr, ssim, loss)``, each summed over batches and divided by
            ``len(dataloader)``.
        """
        # For details see training.
        psnr_value = 0
        ssim_value = 0
        loss_value = 0
        # Always defined so logimage=True with isTrain=True cannot hit a NameError.
        valid_images = []
        num_batches = len(dataloader)

        for index, all_data in enumerate(dataloader, 0):
            self.optimizer.zero_grad()
            (
                Ft_p,
                I0,
                IFrame,
                I1,
                g_I0_F_t_0,
                g_I1_F_t_1,
                FlowBackWarp_I0_F_1_0,
                FlowBackWarp_I1_F_0_1,
                F_1_0,
                F_0_1,
            ) = self.slomofc(all_data, pred_only=False, isTrain=isTrain)

            if (not isTrain) and logimage and index % self.args.logimagefreq == 0:
                for frame in (I0, IFrame, I1, Ft_p):
                    valid_images.append(
                        255.0
                        * frame.cpu()[0]
                        .resize_(1, 1, self.args.data_h, self.args.data_w)
                        .repeat(1, 3, 1, 1)
                    )

            # loss
            loss = self.supervisedloss(
                Ft_p,
                IFrame,
                I0,
                I1,
                g_I0_F_t_0,
                g_I1_F_t_1,
                FlowBackWarp_I0_F_1_0,
                FlowBackWarp_I1_F_0_1,
                F_1_0,
                F_0_1,
            )
            if isTrain:
                loss.backward()
                self.optimizer.step()
                # NOTE(review): stepped per batch; confirm args.milestones are
                # expressed in iterations rather than epochs.
                self.scheduler.step()
            loss_value += loss.item()

            # metrics
            psnr_value += psnr(Ft_p, IFrame, outputTensor=False)
            ssim_value += ssim(Ft_p, IFrame, outputTensor=False)

        mean_psnr = psnr_value / num_batches
        mean_ssim = ssim_value / num_batches
        mean_loss = loss_value / num_batches

        name_loss = "TrainLoss" if isTrain else "ValLoss"
        itr = int(index + epoch * num_batches)
        if self.comet_exp is not None:
            self.comet_exp.log_metric("PSNR", mean_psnr, step=itr, epoch=epoch)
            self.comet_exp.log_metric("SSIM", mean_ssim, step=itr, epoch=epoch)
            self.comet_exp.log_metric(name_loss, mean_loss, step=itr, epoch=epoch)
            if logimage:
                upload_images(
                    valid_images,
                    epoch,
                    exp=self.comet_exp,
                    im_per_row=4,
                    rows_per_log=int(len(valid_images) / 4),
                )
        print(
            " Loss: %0.6f Iterations: %4d/%4d ValPSNR: %0.4f ValSSIM: %0.4f "
            % (mean_loss, index, num_batches, mean_psnr, mean_ssim)
        )
        return (mean_psnr, mean_ssim, mean_loss)