class TransformingAutoEncoderController(nn.Module, ControllableModel):
    '''
    Transforming Auto-Encoder.

    Composes three `TransformerModel`s — an encoder, a decoder, and a
    reconstructor — into an auto-encoding pipeline and drives their
    training, checkpointing, and regularization.

    References:
        - Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
        - Floridi, L., & Chiriatti, M. (2020). GPT-3: Its nature, scope, limits, and consequences. Minds and Machines, 30(4), 681-694.
        - Miller, A., Fisch, A., Dodge, J., Karimi, A. H., Bordes, A., & Weston, J. (2016). Key-value memory networks for directly reading documents. arXiv preprint arXiv:1606.03126.
        - Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018) Improving Language Understanding by Generative Pre-Training. OpenAI (URL: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
        - Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
        - Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
    '''

    # File name remembered when `load_parameters` had to be deferred
    # (consumed on the next `forward` call).
    __loaded_filename = None
    # Device context remembered when `load_parameters` had to be deferred.
    __loaded_ctx = None

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def __init__(
        self,
        computable_loss=None,
        optimizer_f=None,
        encoder=None,
        decoder=None,
        reconstructor=None,
        layer_n=3,
        head_n=3,
        seq_len=5,
        depth_dim=100,
        hidden_dim=100,
        self_attention_activation_list=None,
        multi_head_attention_activation_list=None,
        fc_activation_list=None,
        learning_rate=1e-05,
        weight_decay=0.01,
        dropout_rate=0.5,
        ctx="cpu",
        regularizatable_data_list=None,
        not_init_flag=False,
    ):
        '''
        Init.

        Args:
            computable_loss:    is-a `ComputableLoss` or `torch.nn` loss.
                                Defaults to `nn.CrossEntropyLoss()`.
            optimizer_f:        callable receiving `self.parameters()` and
                                returning a `torch.optim.Optimizer`.
                                If `None`, `AdamW` is used.
            encoder:            is-a `TransformerModel`. Built automatically if `None`.
            decoder:            is-a `TransformerModel`. Built automatically if `None`.
            reconstructor:      is-a `TransformerModel`. Built automatically if `None`.
            layer_n:            `int` of the number of layers.
            head_n:             `int` of the number of heads for multi-head attention model.
            seq_len:            `int` of the length of sequences.
            depth_dim:          `int` of dimension of dense layer.
            hidden_dim:         `int` of dimension of hidden(encoder) layer.
                                When it differs from `depth_dim`, `nn.Linear`
                                projections are inserted in front of the
                                encoder and decoder.
            self_attention_activation_list: `list` of `str` of activation function for self-attention model.
            multi_head_attention_activation_list: `list` of `str` of activation function for multi-head attention model.
            fc_activation_list: `list` of `str` of activation function in fully-connected layers.
            learning_rate:      `float` of learning rate.
            weight_decay:       `float` of weight decay for the default `AdamW` optimizer.
            dropout_rate:       `float` of dropout rate.
            ctx:                Device such as `"cpu"` or `"cuda:0"`.
            regularizatable_data_list: `list` of `RegularizatableData`.
            not_init_flag:      `bool`. If `True`, setting up this controller's
                                optimizer is skipped.
        '''
        super(TransformingAutoEncoderController, self).__init__()

        # Avoid shared mutable defaults (original used `[]` in the signature).
        if self_attention_activation_list is None:
            self_attention_activation_list = []
        if multi_head_attention_activation_list is None:
            multi_head_attention_activation_list = []
        if fc_activation_list is None:
            fc_activation_list = []
        if regularizatable_data_list is None:
            regularizatable_data_list = []

        if computable_loss is None:
            computable_loss = nn.CrossEntropyLoss()
        self.__computable_loss = computable_loss

        if encoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                encoder = TransformerEncoder(
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx,
                )
            else:
                encoder = TransformerEncoder(
                    depth_dim=hidden_dim,
                    layer_n=layer_n,
                    head_n=head_n,
                    self_attention_activation_list=self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    dropout_rate=dropout_rate,
                    ctx=ctx,
                )
                # The external `encoder_hidden_fc` projection feeds this
                # encoder, so its own embedding is disabled.
                encoder.embedding_flag = False
        elif isinstance(encoder, TransformerModel) is False:
            raise TypeError(
                "The type of `encoder` must be `TransformerModel`.")

        if decoder is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=self_attention_activation_list,
                    multi_head_attention_activation_list=multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            else:
                decoder = TransformerDecoder(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=hidden_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=self_attention_activation_list,
                    multi_head_attention_activation_list=multi_head_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
                # The external `decoder_hidden_fc` projection feeds this
                # decoder, so its own embedding is disabled.
                decoder.embedding_flag = False
        elif isinstance(decoder, TransformerModel) is False:
            raise TypeError(
                "The type of `decoder` must be `TransformerModel`.")

        if reconstructor is None:
            if hidden_dim is None or hidden_dim == depth_dim:
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
            else:
                # Project back from `hidden_dim` to the observed `depth_dim`.
                reconstructor = TransformerReconstructor(
                    head_n=head_n,
                    depth_dim=hidden_dim,
                    output_dim=depth_dim,
                    layer_n=layer_n,
                    self_attention_activation_list=self_attention_activation_list,
                    fc_activation_list=fc_activation_list,
                    computable_loss=computable_loss,
                    not_init_flag=not_init_flag,
                    ctx=ctx,
                )
                reconstructor.embedding_flag = False
        elif isinstance(reconstructor, TransformerModel) is False:
            raise TypeError(
                "The type of `reconstructor` must be `TransformerModel`.")

        self.logger = getLogger("accelbrainbase")

        self.encoder = encoder
        self.decoder = decoder
        self.reconstructor = reconstructor

        if hidden_dim is not None and hidden_dim != depth_dim:
            # Linear projections from the observed depth into the hidden
            # dimension in front of encoder and decoder.
            self.encoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            self.decoder_hidden_fc = nn.Linear(
                depth_dim,
                hidden_dim,
                bias=True,
            )
            init_flag = True
        else:
            self.encoder_hidden_fc = None
            self.decoder_hidden_fc = None
            init_flag = False

        self.__ctx = ctx
        self.to(self.__ctx)

        # Fix: always define `self.optimizer` so `learn()` cannot raise
        # `AttributeError` when no projection layers exist (the default
        # configuration, where `hidden_dim == depth_dim`).
        self.optimizer = None
        if self.init_deferred_flag is False:
            if not_init_flag is False:
                if init_flag is True:
                    if optimizer_f is not None:
                        self.optimizer = optimizer_f(self.parameters())
                    else:
                        self.optimizer = AdamW(
                            self.parameters(),
                            lr=learning_rate,
                            weight_decay=weight_decay,
                        )

        for v in regularizatable_data_list:
            if isinstance(v, RegularizatableData) is False:
                raise TypeError(
                    "The type of values of `regularizatable_data_list` must be `RegularizatableData`."
                )
        self.__regularizatable_data_list = regularizatable_data_list

        self.__learning_rate = learning_rate
        self.__weight_decay = weight_decay
        self.seq_len = seq_len
        self.epoch = 0

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Args:
            iteratable_data:    is-a `TransformerIterator`.
        '''
        if isinstance(iteratable_data, TransformerIterator) is False:
            raise TypeError(
                "The type of `iteratable_data` must be `TransformerIterator`.")

        self.__loss_list = []
        learning_rate = self.__learning_rate
        epoch = self.epoch
        try:
            iter_n = 0
            for encoded_observed_arr, decoded_observed_arr, encoded_mask_arr, decoded_mask_arr, test_encoded_observed_arr, test_decoded_observed_arr, test_encoded_mask_arr, test_decoded_mask_arr, training_target_arr, test_target_arr in iteratable_data.generate_learned_samples():
                self.epoch = epoch
                if self.encoder.optimizer is not None and self.decoder.optimizer is not None:
                    optimizer_setup_flag = True
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    if self.optimizer is not None:
                        self.optimizer.zero_grad()
                else:
                    optimizer_setup_flag = False

                pred_arr = self.inference(encoded_observed_arr,
                                          decoded_observed_arr,
                                          encoded_mask_arr,
                                          decoded_mask_arr)
                loss = self.compute_loss(pred_arr, training_target_arr)

                if optimizer_setup_flag is False:
                    # The first forward pass lazily set up the sub-model
                    # optimizers; zero them and redo the forward pass so the
                    # backward pass starts from clean gradients.
                    self.encoder.optimizer.zero_grad()
                    self.decoder.optimizer.zero_grad()
                    if self.optimizer is not None:
                        self.optimizer.zero_grad()
                    pred_arr = self.inference(encoded_observed_arr,
                                              decoded_observed_arr,
                                              encoded_mask_arr,
                                              decoded_mask_arr)
                    loss = self.compute_loss(pred_arr, training_target_arr)

                loss.backward()
                if self.optimizer is not None:
                    self.optimizer.step()
                self.decoder.optimizer.step()
                self.encoder.optimizer.step()
                self.regularize()
                self.decoder.regularize()
                self.encoder.regularize()

                # Fix: guard against `ZeroDivisionError` when
                # `iter_n < epochs` made the interval zero.
                interval_n = int(iteratable_data.iter_n / iteratable_data.epochs)
                if interval_n > 0 and (iter_n + 1) % interval_n == 0:
                    # Fix: the original `if torch.inference_mode():` only
                    # tested the truthiness of the context-manager object;
                    # `with` actually disables autograd for the test pass.
                    with torch.inference_mode():
                        test_pred_arr = self.inference(
                            test_encoded_observed_arr,
                            test_decoded_observed_arr,
                            test_encoded_mask_arr,
                            test_decoded_mask_arr)
                        test_loss = self.compute_loss(test_pred_arr,
                                                      test_target_arr)
                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()
                    self.__loss_list.append((_loss, _test_loss))
                    self.logger.debug("Epochs: " + str(epoch + 1) +
                                      " Train loss: " + str(_loss) +
                                      " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1
        except KeyboardInterrupt:
            self.logger.debug("Interrupt.")

        self.logger.debug("end. ")
        self.epoch = epoch

    def inference(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Inference samples drawn by `IteratableData.generate_inferenced_samples()`.

        Args:
            encoded_observed_arr:   rank-3 tensor of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)
            decoded_observed_arr:   rank-3 tensor of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)
            encoded_mask_arr:       mask tensor for the encoder, or `None`.
            decoded_mask_arr:       mask tensor for the decoder, or `None`.

        Returns:
            `torch.Tensor` of inferenced feature points.
        '''
        return self(
            encoded_observed_arr,
            decoded_observed_arr,
            encoded_mask_arr,
            decoded_mask_arr,
        )

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor` of predicted data.
            labeled_arr:    `torch.Tensor` of labeled data.

        Returns:
            loss.
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def forward(
        self,
        encoded_observed_arr,
        decoded_observed_arr,
        encoded_mask_arr=None,
        decoded_mask_arr=None,
    ):
        '''
        Forward propagation.

        Args:
            encoded_observed_arr:   rank-3 tensor of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)
            decoded_observed_arr:   rank-3 tensor of observed data points.
                                    The shape is: (batch size, the length of sequence, feature points)
            encoded_mask_arr:       mask tensor for the encoder, or `None`.
            decoded_mask_arr:       mask tensor for the decoder, or `None`.

        Returns:
            `torch.Tensor` of inferenced feature points.
        '''
        if self.__loaded_filename is not None:
            # A previous `load_parameters` call was deferred (lazy modules
            # were not built yet). Run one detached forward pass to build
            # all parameters, then retry loading.
            loaded_filename = self.__loaded_filename
            self.__loaded_filename = None
            init_encoded_observed_arr = encoded_observed_arr.detach()
            init_decoded_observed_arr = decoded_observed_arr.detach()
            if encoded_mask_arr is not None:
                init_encoded_mask_arr = encoded_mask_arr.detach()
            else:
                init_encoded_mask_arr = None
            if decoded_mask_arr is not None:
                init_decoded_mask_arr = decoded_mask_arr.detach()
            else:
                init_decoded_mask_arr = None
            _ = self.forward(
                init_encoded_observed_arr,
                init_decoded_observed_arr,
                init_encoded_mask_arr,
                init_decoded_mask_arr,
            )
            self.load_parameters(loaded_filename, ctx=self.__loaded_ctx)
            self.__loaded_ctx = None

        # Default masks are all-ones (no position is masked out).
        if encoded_mask_arr is None:
            encoded_mask_arr = torch.ones(
                (encoded_observed_arr.shape[0], 1, 1, 1),
            )
            encoded_mask_arr = encoded_mask_arr.to(encoded_observed_arr.device)
        if decoded_mask_arr is None:
            decoded_mask_arr = torch.ones(
                (decoded_observed_arr.shape[0], 1, 1, 1),
            )
            decoded_mask_arr = decoded_mask_arr.to(decoded_observed_arr.device)

        if self.encoder_hidden_fc is not None:
            encoded_observed_arr = self.encoder_hidden_fc(encoded_observed_arr)
        if self.decoder_hidden_fc is not None:
            decoded_observed_arr = self.decoder_hidden_fc(decoded_observed_arr)

        encoded_arr = self.encoder(encoded_observed_arr, encoded_mask_arr)
        decoded_arr = self.decoder(
            decoded_observed_arr,
            encoded_arr,
            decoded_mask_arr,
            encoded_mask_arr,
        )
        # Expose the decoder output as the feature points of this model.
        self.feature_points_arr = decoded_arr
        reconstructed_arr = self.reconstructor(decoded_arr, None)
        return reconstructed_arr

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        # Snapshot the state dict once instead of rebuilding it per key.
        return dict(self.state_dict())

    def regularize(self):
        '''
        Regularization.
        '''
        if len(self.__regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.__regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)
            # Load all regularized parameters in a single call
            # (the original reloaded once per key).
            self.load_state_dict(params_dict, strict=False)

    def __rename_file(self, filename):
        '''
        Derive per-sub-model file names from `filename` by inserting a
        suffix before the extension.

        Args:
            filename:   Base file name such as `model.pt`.

        Returns:
            Tuple of `(encoder_filename, decoder_filename, reconstructor_filename)`.
        '''
        filename_list = filename.split(".")
        _format = filename_list[-1]
        encoder_filename = filename.replace("." + _format,
                                            "_encoder." + _format)
        decoder_filename = filename.replace("." + _format,
                                            "_decoder." + _format)
        reconstructor_filename = filename.replace("." + _format,
                                                  "_reconstructor." + _format)
        return encoder_filename, decoder_filename, reconstructor_filename

    def save_parameters(self, filename):
        '''
        Save parameters to files.

        Args:
            filename:       File name.
        '''
        encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
            filename)
        self.encoder.epoch = self.epoch
        self.encoder.loss_arr = self.loss_arr
        self.encoder.save_parameters(encoder_filename)
        self.decoder.epoch = self.epoch
        self.decoder.loss_arr = self.loss_arr
        self.decoder.save_parameters(decoder_filename)
        self.reconstructor.epoch = self.epoch
        self.reconstructor.loss_arr = self.loss_arr
        # Fix: the reconstructor was saved to `decoder_filename`, which
        # overwrote the decoder checkpoint and left no reconstructor file.
        self.reconstructor.save_parameters(reconstructor_filename)

        torch.save(
            {
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'epoch': self.epoch,
                'loss': self.loss_arr,
            }, filename)

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from files.

        Args:
            filename:       File name.
            ctx:            Device to move the models to after loading.
            strict:         Whether to strictly enforce that the keys in state_dict
                            match the keys returned by this module's state_dict()
                            function. Default: `True`.
        '''
        try:
            encoder_filename, decoder_filename, reconstructor_filename = self.__rename_file(
                filename)
            self.encoder.load_parameters(encoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.decoder.load_parameters(decoder_filename,
                                         ctx=ctx,
                                         strict=strict)
            self.reconstructor.load_parameters(reconstructor_filename,
                                               ctx=ctx,
                                               strict=strict)
            checkpoint = torch.load(filename)
            self.load_state_dict(checkpoint['model_state_dict'],
                                 strict=strict)
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.epoch = checkpoint['epoch']
            self.__loss_list = checkpoint['loss'].tolist()
        except RuntimeError:
            # Lazy modules are not built yet; defer loading until the next
            # `forward` call (see `forward`).
            self.__loaded_filename = filename
            self.__loaded_ctx = ctx

        if ctx is not None:
            self.to(ctx)
            self.encoder.to(ctx)
            self.decoder.to(ctx)
            self.reconstructor.to(ctx)
            self.__ctx = ctx

    def set_readonly(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not.'''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not.'''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag,
                                  set_init_deferred_flag)

    # Pairs of `(train_loss, test_loss)` accumulated by `learn`.
    __loss_list = []

    def get_loss_arr(self):
        ''' getter '''
        return np.array(self.__loss_list)

    def set_loss_arr(self, value):
        ''' setter '''
        raise TypeError("This property must be read-only.")

    loss_arr = property(get_loss_arr, set_loss_arr)
def train():
    '''
    Train the Super SloMo frame-interpolation networks end to end.

    Parses command-line arguments, builds the flow-computation and
    arbitrary-time flow-interpolation U-Nets, datasets and loaders, the
    optimizer and LR scheduler, then runs the main training loop with
    periodic validation, TensorBoard logging, and checkpointing.

    Relies on a module-level TensorBoard `writer` and the project-local
    `model` and `dataloader` modules.
    '''
    global writer

    def str2bool(value):
        # Fix: argparse `type=bool` calls `bool(...)` on the raw string, so
        # any non-empty string -- including "False" -- parsed as True.
        # Parse textual booleans explicitly instead.
        if isinstance(value, bool):
            return value
        if value.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if value.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected, got %r." % value)

    # For parsing commandline arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help='path to dataset folder containing train-test-validation folders')
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        required=True,
                        help='path to folder for saving checkpoints')
    parser.add_argument("--checkpoint",
                        type=str,
                        help='path of checkpoint for pretrained model')
    parser.add_argument(
        "--train_continue",
        type=str2bool,
        default=False,
        help=
        'If resuming from checkpoint, set to True and set `checkpoint` path. Default: False.'
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        help='number of epochs to train. Default: 200.')
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=3,
                        help='batch size for training. Default: 6.')
    parser.add_argument("--validation_batch_size",
                        type=int,
                        default=6,
                        help='batch size for validation. Default: 10.')
    parser.add_argument("--init_learning_rate",
                        type=float,
                        default=0.0001,
                        help='set initial learning rate. Default: 0.0001.')
    # NOTE(review): `type=list` splits a passed string into single characters;
    # left as-is because the option is declared unused.
    parser.add_argument(
        "--milestones",
        type=list,
        default=[25, 50],
        help=
        'UNUSED NOW: Set to epoch values where you want to decrease learning rate by a factor of 0.1. Default: [100, 150]'
    )
    parser.add_argument(
        "--progress_iter",
        type=int,
        default=200,
        help=
        'frequency of reporting progress and validation. N: after every N iterations. Default: 100.'
    )
    parser.add_argument(
        "--checkpoint_epoch",
        type=int,
        default=5,
        help=
        'checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly of size 151 MB.Default: 5.'
    )
    args = parser.parse_args()

    ##[TensorboardX](https://github.com/lanpa/tensorboardX)
    ### For visualizing loss and interpolated frames

    ###Initialize flow computation and arbitrary-time flow interpolation CNNs.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    # flowComp maps two stacked RGB frames (6ch) to bidirectional flow (4ch).
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    # ArbTimeFlowIntrp refines intermediate flows and predicts a visibility map.
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)

    ###Initialze backward warpers for train and validation datasets
    train_W_dim = 352
    train_H_dim = 352

    trainFlowBackWarp = model.backWarp(train_W_dim, train_H_dim, device)
    trainFlowBackWarp = trainFlowBackWarp.to(device)
    # Validation crops are twice as wide as training crops.
    validationFlowBackWarp = model.backWarp(train_W_dim * 2, train_H_dim,
                                            device)
    validationFlowBackWarp = validationFlowBackWarp.to(device)

    ###Load Datasets

    # Channel wise mean calculated on custom training dataset
    # mean = [0.43702903766008444, 0.43715053433990597, 0.40436416782660994]
    mean = [0.5] * 3
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train',
                                     randomCropSize=(train_W_dim, train_H_dim),
                                     transform=transform,
                                     train=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    validationset = dataloader.SuperSloMo(
        root=args.dataset_root + '/validation',
        transform=transform,
        randomCropSize=(2 * train_W_dim, train_H_dim),
        train=False)
    validationloader = torch.utils.data.DataLoader(
        validationset,
        batch_size=args.validation_batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True)

    print(trainset, validationset)

    ###Create transform to display image from tensor
    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)
    TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    ###Utils
    def get_lr(optimizer):
        # Report the LR of the first parameter group (all groups share one LR).
        for param_group in optimizer.param_groups:
            return param_group['lr']

    ###Loss and Optimizer
    L1_lossFn = nn.L1Loss()
    MSE_LossFn = nn.MSELoss()

    if args.train_continue:
        dict1 = torch.load(args.checkpoint)
        # `last_epoch` would matter for iteration-indexed schedulers such as
        # the commented-out CyclicLR; ReduceLROnPlateau does not use it.
        last_epoch = dict1['epoch'] * len(trainloader)
    else:
        last_epoch = -1

    params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())

    optimizer = AdamW(params, lr=args.init_learning_rate, amsgrad=True)
    # optimizer = optim.SGD(params, lr=args.init_learning_rate, momentum=0.9, nesterov=True)

    # scheduler to decrease learning rate by a factor of 10 at milestones.
    # Patience suggested value:
    # patience = number of item in train dataset / train_batch_size * (Number of epochs patience)
    # It does say epoch, but in this case, the number of progress iterations is what's really being worked with.
    # As such, each epoch will be given by the above formula (roughly, if using a rough dataset count)
    # If the model seems to equalize fast, reduce the number of epochs accordingly.
    # scheduler = optim.lr_scheduler.CyclicLR(optimizer,
    #                                         base_lr=1e-8,
    #                                         max_lr=9.0e-3,
    #                                         step_size_up=3500,
    #                                         mode='triangular2',
    #                                         cycle_momentum=False,
    #                                         last_epoch=last_epoch)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=len(trainloader) * 3,
        cooldown=len(trainloader) * 2,
        verbose=True,
        min_lr=1e-8)
    # Changed to use this to ensure a more adaptive model.
    # The changed model used here seems to converge or plateau faster with more rapid swings over time.
    # As such letting the model deal with stagnation more proactively than at a set stage seems more useful.

    ###Initializing VGG16 model for perceptual loss
    vgg16 = torchvision.models.vgg16(pretrained=True)
    vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])
    vgg16_conv_4_3.to(device)
    # VGG is only a fixed feature extractor for the perceptual loss.
    for param in vgg16_conv_4_3.parameters():
        param.requires_grad = False

    # Validation function
    def validate():
        # For details see training.
        psnr = 0
        tloss = 0
        flag = 1
        # Fix: keep `retImg` defined even if the validation loader is empty.
        retImg = None
        with torch.no_grad():
            for validationIndex, (validationData,
                                  validationFrameIndex) in enumerate(
                                      validationloader, 0):
                frame0, frameT, frame1 = validationData

                I0 = frame0.to(device)
                I1 = frame1.to(device)
                IFrame = frameT.to(device)
                torch.cuda.empty_cache()
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = model.getFlowCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)
                torch.cuda.empty_cache()
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0
                # torch.cuda.empty_cache()
                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)
                wCoeff = model.getWarpCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f +
                        wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (
                            wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # For tensorboard
                if (flag):
                    retImg = torchvision.utils.make_grid([
                        revNormalize(frame0[0]),
                        revNormalize(frameT[0]),
                        revNormalize(Ft_p.cpu()[0]),
                        revNormalize(frame1[0])
                    ],
                                                         padding=10)
                    flag = 0

                # loss
                recnLoss = L1_lossFn(Ft_p, IFrame)
                # torch.cuda.empty_cache()
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))

                warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                    g_I1_F_t_1, IFrame) + L1_lossFn(
                        validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                            validationFlowBackWarp(I1, F_0_1), I0)
                torch.cuda.empty_cache()
                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] -
                              F_1_0[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_1_0[:, :, :-1, :] -
                                            F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] -
                              F_0_1[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_0_1[:, :, :-1, :] -
                                            F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1
                # torch.cuda.empty_cache()
                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth
                tloss += loss.item()

                # psnr
                MSE_val = MSE_LossFn(Ft_p, IFrame)
                psnr += (10 * log10(1 / MSE_val.item()))
                torch.cuda.empty_cache()
        return (psnr / len(validationloader)), (tloss /
                                                len(validationloader)), retImg

    ### Initialization
    if args.train_continue:
        ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        flowComp.load_state_dict(dict1['state_dictFC'])
        # Fix: older checkpoints may lack optimizer/scheduler state; feeding
        # the former `.get(..., {})` empty dict into `load_state_dict` would
        # raise `KeyError('param_groups')`. Restore only when present.
        if dict1.get('state_optimizer'):
            optimizer.load_state_dict(dict1['state_optimizer'])
        if dict1.get('state_scheduler'):
            scheduler.load_state_dict(dict1['state_scheduler'])
        for param_group in optimizer.param_groups:
            param_group['lr'] = dict1.get('learningRate',
                                          args.init_learning_rate)
    else:
        dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    ### Training
    import time
    start = time.time()
    cLoss = dict1['loss']
    valLoss = dict1['valLoss']
    valPSNR = dict1['valPSNR']
    checkpoint_counter = 0

    ### Main training loop
    # NOTE(review): the dummy `optimizer.step()` before the loop silences the
    # "scheduler.step() before optimizer.step()" warning; kept intentionally.
    optimizer.step()
    for epoch in range(dict1['epoch'] + 1, args.epochs):
        print("Epoch: ", epoch)

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        iLoss = 0

        for trainIndex, (trainData,
                         trainFrameIndex) in enumerate(trainloader, 0):

            ## Getting the input and the target from the training set
            frame0, frameT, frame1 = trainData

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            IFrame = frameT.to(device)

            optimizer.zero_grad()
            # torch.cuda.empty_cache()
            # Calculate flow between reference frames I0 and I1
            flowOut = flowComp(torch.cat((I0, I1), dim=1))

            # Extracting flows between I0 and I1 - F_0_1 and F_1_0
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            fCoeff = model.getFlowCoeff(trainFrameIndex, device)

            # Calculate intermediate flows
            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)
            torch.cuda.empty_cache()
            # Calculate optical flow residuals and visibility maps
            intrpOut = ArbTimeFlowIntrp(
                torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                           g_I0_F_t_0),
                          dim=1))

            # Extract optical flow residuals and visibility maps
            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1 = 1 - V_t_0
            # torch.cuda.empty_cache()
            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)
            # torch.cuda.empty_cache()
            wCoeff = model.getWarpCoeff(trainFrameIndex, device)
            torch.cuda.empty_cache()
            # Calculate final intermediate frame
            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f +
                    wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 +
                                                         wCoeff[1] * V_t_1)

            # Loss
            recnLoss = L1_lossFn(Ft_p, IFrame)
            # torch.cuda.empty_cache()

            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))
            # torch.cuda.empty_cache()

            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                g_I1_F_t_1, IFrame) + L1_lossFn(
                    trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                        trainFlowBackWarp(I1, F_0_1), I0)

            loss_smooth_1_0 = torch.mean(
                torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
            loss_smooth_0_1 = torch.mean(
                torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1
            # torch.cuda.empty_cache()
            # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4
            # since the loss in paper is calculated for input pixels in range 0-255
            # and the input to our network is in range 0-1
            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Backpropagate
            loss.backward()
            optimizer.step()

            # NOTE(review): the scheduler is also stepped again in the
            # progress block below with the averaged loss; kept as-is since
            # the patience/cooldown values were tuned against this behavior.
            scheduler.step(loss.item())

            iLoss += loss.item()
            torch.cuda.empty_cache()
            # Validation and progress every `args.progress_iter` iterations
            if ((trainIndex % args.progress_iter) == args.progress_iter - 1):
                # Increment scheduler count
                scheduler.step(iLoss / args.progress_iter)
                end = time.time()

                psnr, vLoss, valImg = validate()
                optimizer.zero_grad()
                # torch.cuda.empty_cache()
                valPSNR[epoch].append(psnr)
                valLoss[epoch].append(vLoss)

                # Tensorboard
                itr = trainIndex + epoch * (len(trainloader))

                writer.add_scalars(
                    'Loss', {
                        'trainLoss': iLoss / args.progress_iter,
                        'validationLoss': vLoss
                    }, itr)
                writer.add_scalar('PSNR', psnr, itr)

                writer.add_image('Validation', valImg, itr)
                #####

                endVal = time.time()

                print(
                    " Loss: %0.6f Iterations: %4d/%4d TrainExecTime: %0.1f ValLoss:%0.6f ValPSNR: %0.4f ValEvalTime: %0.2f LearningRate: %.1e"
                    % (iLoss / args.progress_iter, trainIndex,
                       len(trainloader), end - start, vLoss, psnr,
                       endVal - end, get_lr(optimizer)))
                # torch.cuda.empty_cache()
                cLoss[epoch].append(iLoss / args.progress_iter)
                iLoss = 0
                start = time.time()

        # Create checkpoint after every `args.checkpoint_epoch` epochs
        if (epoch % args.checkpoint_epoch) == args.checkpoint_epoch - 1:
            dict1 = {
                'Detail': "End to end Super SloMo.",
                'epoch': epoch,
                'timestamp': datetime.datetime.now(),
                'trainBatchSz': args.train_batch_size,
                'validationBatchSz': args.validation_batch_size,
                'learningRate': get_lr(optimizer),
                'loss': cLoss,
                'valLoss': valLoss,
                'valPSNR': valPSNR,
                'state_dictFC': flowComp.state_dict(),
                'state_dictAT': ArbTimeFlowIntrp.state_dict(),
                'state_optimizer': optimizer.state_dict(),
                'state_scheduler': scheduler.state_dict()
            }
            torch.save(
                dict1, args.checkpoint_dir + "/SuperSloMo" +
                str(checkpoint_counter) + ".ckpt")
            checkpoint_counter += 1
class TransformerModel(ObservableData):
    '''
    The abstract class of Transformer.

    References:
        - Bahdanau, D., Cho, K., & Bengio, Y. (2014). Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
        - Floridi, L., & Chiriatti, M. (2020). GPT-3: Its nature, scope, limits, and consequences. Minds and Machines, 30(4), 681-694.
        - Miller, A., Fisch, A., Dodge, J., Karimi, A. H., Bordes, A., & Weston, J. (2016). Key-value memory networks for directly reading documents. arXiv preprint arXiv:1606.03126.
        - Radford, A., Narasimhan, K., Salimans, T., & Sutskever, I. (2018) Improving Language Understanding by Generative Pre-Training. OpenAI (URL: https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
        - Radford, A., Wu, J., Child, R., Luan, D., Amodei, D., & Sutskever, I. (2019). Language models are unsupervised multitask learners. OpenAI blog, 1(8), 9.
        - Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., & Polosukhin, I. (2017). Attention is all you need. arXiv preprint arXiv:1706.03762.
    '''

    # Optimizer instance. Lazily created by subclasses or by `load_parameters`.
    __optimizer = None

    def get_optimizer(self):
        ''' getter for the optimizer. '''
        return self.__optimizer

    def set_optimizer(self, value):
        ''' setter for the optimizer. '''
        self.__optimizer = value

    optimizer = property(get_optimizer, set_optimizer)

    # `bool`. If `False`, `embedding` is the identity function (no positional encoding).
    __embedding_flag = True

    def get_embedding_flag(self):
        ''' getter for `bool` that means the positional encoding is applied or not. '''
        return self.__embedding_flag

    def set_embedding_flag(self, value):
        ''' setter for `bool` that means the positional encoding is applied or not. '''
        self.__embedding_flag = value

    embedding_flag = property(get_embedding_flag, set_embedding_flag)

    # `float` scale applied to the positional encoding added in `embedding`.
    # NOTE(review): the misspelling `weignt` (for "weight") is part of the public
    # API of this class and is kept for backward compatibility with callers.
    __embedding_weignt = 1.0

    def get_embedding_weignt(self):
        ''' getter for the weight of the positional encoding. '''
        return self.__embedding_weignt

    def set_embedding_weignt(self, value):
        ''' setter for the weight of the positional encoding. '''
        self.__embedding_weignt = value

    embedding_weignt = property(get_embedding_weignt, set_embedding_weignt)

    def learn(self, iteratable_data):
        '''
        Learn samples drawn by `IteratableData.generate_learned_samples()`.

        Args:
            iteratable_data:     is-a `IteratableData`.

        Raises:
            TypeError:  If `iteratable_data` is not an `IteratableData`.
        '''
        if isinstance(iteratable_data, IteratableData) is False:
            raise TypeError("The type of `iteratable_data` must be `IteratableData`.")

        self.__loss_list = []
        try:
            epoch = 0
            iter_n = 0
            # Evaluate/log once every `eval_interval` iterations.
            # Guarded so that `iter_n < epochs` (or `epochs == 0`) does not raise
            # `ZeroDivisionError` as the original unguarded modulo did.
            if iteratable_data.epochs > 0:
                eval_interval = max(1, int(iteratable_data.iter_n / iteratable_data.epochs))
            else:
                eval_interval = 1

            for batch_observed_arr, batch_target_arr, test_batch_observed_arr, test_batch_target_arr in iteratable_data.generate_learned_samples():
                self.epoch = epoch
                if self.optimizer is not None:
                    self.optimizer.zero_grad()
                    optimizer_setup_flag = True
                else:
                    # The optimizer is lazily set up by the first `inference` call below.
                    optimizer_setup_flag = False

                # Self-Attention.
                if len(batch_observed_arr.shape) == 2:
                    batch_observed_arr = torch.unsqueeze(batch_observed_arr, axis=1)

                pred_arr = self.inference(batch_observed_arr, batch_observed_arr)
                loss = self.compute_loss(
                    pred_arr,
                    batch_target_arr
                )

                if optimizer_setup_flag is False:
                    # After initialization, restart the forward pass so the freshly
                    # created optimizer sees zeroed gradients before `backward`.
                    self.optimizer.zero_grad()
                    pred_arr = self.inference(batch_observed_arr, batch_observed_arr)
                    # BUG FIX: the training loss was computed against
                    # `test_batch_target_arr` here; use the training targets,
                    # consistently with the first pass above.
                    loss = self.compute_loss(
                        pred_arr,
                        batch_target_arr
                    )

                loss.backward()
                self.optimizer.step()
                self.regularize()

                if (iter_n + 1) % eval_interval == 0:
                    with torch.inference_mode():
                        if len(test_batch_observed_arr.shape) == 2:
                            test_batch_observed_arr = torch.unsqueeze(test_batch_observed_arr, axis=1)
                        test_pred_arr = self.inference(test_batch_observed_arr, test_batch_observed_arr)
                        # BUG FIX: the test loss was computed against the test
                        # *observations*; compare predictions with the test targets,
                        # mirroring the training pass.
                        test_loss = self.compute_loss(
                            test_pred_arr,
                            test_batch_target_arr
                        )

                    _loss = loss.to('cpu').detach().numpy().copy()
                    _test_loss = test_loss.to('cpu').detach().numpy().copy()
                    self.__loss_list.append((_loss, _test_loss))
                    self.logger.debug("Epochs: " + str(epoch + 1) + " Train loss: " + str(_loss) + " Test loss: " + str(_test_loss))
                    epoch += 1
                iter_n += 1
        except KeyboardInterrupt:
            self.logger.debug("Interrupt.")

        self.logger.debug("end. ")
        self.epoch = epoch

    def compute_loss(self, pred_arr, labeled_arr):
        '''
        Compute loss.

        Args:
            pred_arr:       `torch.Tensor` of inferenced feature points.
            labeled_arr:    `torch.Tensor` of labeled data.

        Returns:
            loss (as returned by the `__computable_loss` callable set elsewhere).
        '''
        return self.__computable_loss(pred_arr, labeled_arr)

    def regularize(self):
        '''
        Regularization.

        Applies each item of `self.regularizatable_data_list` to the current
        parameters and writes the regularized values back into the module.
        '''
        if len(self.regularizatable_data_list) > 0:
            params_dict = self.extract_learned_dict()
            for regularizatable in self.regularizatable_data_list:
                params_dict = regularizatable.regularize(params_dict)

            # `strict=False` because each call loads a single key at a time.
            for k, params in params_dict.items():
                self.load_state_dict({k: params}, strict=False)

    def extract_learned_dict(self):
        '''
        Extract (pre-) learned parameters.

        Returns:
            `dict` of the parameters.
        '''
        # Snapshot `state_dict()` once instead of rebuilding it per key
        # as the original `setdefault` loop did.
        state_dict = self.state_dict()
        return {k: state_dict[k] for k in state_dict}

    def embedding(self, observed_arr):
        '''
        Embedding. In default, this method does the positional encoding.

        Args:
            observed_arr:   `torch.Tensor` of observed data points,
                            shaped (batch_size, seq_len, depth_dim).

        Returns:
            `torch.Tensor` of embedded data points (same shape as the input).
        '''
        batch_size, seq_len, depth_dim = observed_arr.shape
        self.batch_size = batch_size
        self.seq_len = seq_len
        self.depth_dim = depth_dim
        if self.embedding_flag is False:
            return observed_arr

        # Wavelength term 10000^(2i / depth_dim), broadcast over the sequence axis.
        arr = torch.from_numpy(np.arange(depth_dim))
        arr = arr.to(observed_arr.device)
        depth_arr = torch.tile(
            torch.unsqueeze(
                (arr / 2).to(torch.int32) * 2,
                0
            ),
            (seq_len, 1)
        )
        depth_arr = depth_arr / depth_dim
        depth_arr = torch.pow(10000.0, depth_arr).to(torch.float32)

        # pi/2 phase shift on odd dimensions turns `sin` into `cos`,
        # giving the alternating sin/cos encoding of Vaswani et al. (2017).
        arr = torch.from_numpy(np.arange(depth_dim))
        arr = arr.to(observed_arr.device)
        phase_arr = torch.tile(
            torch.unsqueeze(
                (arr % 2) * np.pi / 2,
                0
            ),
            (seq_len, 1)
        )

        # Position index, broadcast over the depth axis.
        arr = torch.from_numpy(np.arange(seq_len))
        arr = arr.to(observed_arr.device)
        positional_arr = torch.tile(
            torch.unsqueeze(arr, 1),
            (1, depth_dim)
        )

        sin_arr = torch.sin(positional_arr / depth_arr + phase_arr)
        positional_encoded_arr = torch.tile(
            torch.unsqueeze(sin_arr, 0),
            (batch_size, 1, 1)
        )
        return observed_arr + (positional_encoded_arr * self.embedding_weignt)

    def save_parameters(self, filename):
        '''
        Save parameters to a file.

        Args:
            filename:       File name.
        '''
        # NOTE(review): `loss_arr` is an independent attribute; `learn` records
        # its losses in a private list instead — confirm which one callers
        # expect to be checkpointed.
        torch.save(
            {
                'epoch': self.epoch,
                'model_state_dict': self.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': self.loss_arr,
            },
            filename
        )

    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from a file.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict
                            match the keys returned by this module's state_dict()
                            function. Default: `True`.
        '''
        checkpoint = torch.load(filename)
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)
        try:
            if self.optimizer is None:
                # Rebuild the optimizer before restoring its state:
                # prefer the user-supplied factory, fall back to AdamW.
                if self.optimizer_f is not None:
                    self.optimizer = self.optimizer_f(
                        self.parameters()
                    )
                else:
                    self.optimizer = AdamW(
                        self.parameters(),
                        lr=self.learning_rate,
                        weight_decay=self.weight_decay
                    )
            self.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict']
            )
        except ValueError as e:
            # Parameter-group mismatch: keep going with a fresh optimizer state.
            self.logger.debug(e)
            self.logger.debug("The state of the optimizer in `TransformerModel` was not updated.")

        self.epoch = checkpoint['epoch']
        # `checkpoint['loss']` is assumed to be array-like (see `save_parameters`).
        self.__loss_list = checkpoint['loss'].tolist()
        if ctx is not None:
            self.to(ctx)
            self.__ctx = ctx

    # `int` of the current epoch.
    __epoch = 0

    def get_epoch(self):
        ''' getter for epoch. '''
        return self.__epoch

    def set_epoch(self, value):
        ''' setter for epoch. '''
        self.__epoch = value

    epoch = property(get_epoch, set_epoch)

    # `np.ndarray` of recorded losses.
    __loss_arr = np.array([])

    def get_loss_arr(self):
        ''' getter for losses. '''
        return self.__loss_arr

    def set_loss_arr(self, value):
        ''' setter for losses. '''
        self.__loss_arr = value

    loss_arr = property(get_loss_arr, set_loss_arr)

    # `bool` that means initialization in this class will be deferred or not.
    __init_deferred_flag = False

    def get_init_deferred_flag(self):
        ''' getter for `bool` that means initialization in this class will be deferred or not. '''
        return self.__init_deferred_flag

    def set_init_deferred_flag(self, value):
        ''' setter for `bool` that means initialization in this class will be deferred or not. '''
        self.__init_deferred_flag = value

    init_deferred_flag = property(get_init_deferred_flag, set_init_deferred_flag)