    def init_trainer_network(self):
        images = self.forward_generator()
        _ = self.forward_discriminator(images, self.condition)
        self.reset_parameters()

        gen_params = list(self.generator.parameters())
        dsc_params = list(self.discriminators.parameters())
        self.optim_generator = AdamW(gen_params,
                                     lr=self.cfg.gen_lr,
                                     betas=self.cfg.gen_betas,
                                     weight_decay=self.cfg.gen_wd)

        self.optim_discriminator = AdamW(dsc_params,
                                         lr=self.cfg.dsc_lr,
                                         betas=self.cfg.dsc_betas,
                                         weight_decay=self.cfg.dsc_wd)

        # TODO: add the language-optimizer hyperparameters to the YAML config
        feature_params = list(self.features_model.parameters())
        if self.language_model.trainable:
            feature_params = feature_params + list(
                self.language_model.parameters())
        self.optim_feature = AdamW(feature_params,
                                   lr=self.cfg.gen_lr,
                                   betas=self.cfg.gen_betas,
                                   weight_decay=self.cfg.gen_wd)
    def init_trainer_network(self):
        self.reset_parameters()
        self.gen_params = list(self.generator.parameters())
        self.dsc_params = list(self.discriminator.parameters())
        self.language_params = list(self.language_model.parameters())
        self.style_params = list(self.style_model.parameters())
        self.content_params = list(self.content_model.parameters())
        self.optim_language = AdamW(self.language_params,
                                    lr=self.cfg.lang_lr,
                                    betas=self.cfg.lang_betas,
                                    weight_decay=self.cfg.lang_wd)
        self.optim_generator = AdamW(self.gen_params,
                                     lr=self.cfg.gen_lr,
                                     betas=self.cfg.gen_betas,
                                     weight_decay=self.cfg.gen_wd)
        self.optim_discriminator = AdamW(self.dsc_params,
                                         lr=self.cfg.dsc_lr,
                                         betas=self.cfg.dsc_betas,
                                         weight_decay=self.cfg.dsc_wd)
        self.optim_style = AdamW(self.style_params,
                                 lr=self.cfg.lmf_lr,
                                 betas=self.cfg.lmf_betas,
                                 weight_decay=self.cfg.lmf_wd)
        self.optim_content = AdamW(self.content_params,
                                   lr=self.cfg.lmf_lr,
                                   betas=self.cfg.lmf_betas,
                                   weight_decay=self.cfg.lmf_wd)
Example #3
    def configure_optimizers(self):
        opt_cfg = self.cfg['optimizer']
        sched_cfg = self.cfg['scheduler']
        lr = float(opt_cfg['lr'])
        if opt_cfg['name'] == 'AdamW':
            optimizer = AdamW(self.model.parameters(), lr=lr)
        elif opt_cfg['name'] == 'Adam_GCC':
            optimizer = Adam_GCC(self.model.parameters(), lr=lr)
        elif opt_cfg['name'] == 'AdamW_GCC2':
            optimizer = AdamW_GCC2(self.model.parameters(), lr=lr)
        else:
            raise Exception('optimizer {} not supported'.format(opt_cfg['name']))

        if sched_cfg['type'] == 'none':
            sched = None
        elif sched_cfg['type'] == 'CosineAnnealingWarmRestarts':
            T_mult = sched_cfg['T_mult']
            T_0 = sched_cfg['T_0']
            eta_min = float(sched_cfg['eta_min'])
            sched = CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult, eta_min=eta_min, last_epoch=-1)
        elif sched_cfg['type'] == 'OneCycleLR':
            max_lr = float(sched_cfg['max_lr'])
            steps_per_epoch = sched_cfg['steps_per_epoch']
            epochs = sched_cfg['epochs']
            sched = OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=steps_per_epoch, epochs=epochs)
        else:
            raise Exception('scheduler {} not supported'.format(sched_cfg['type']))

        if sched is None:
            return optimizer
        sched = {'scheduler': sched, 'name': sched_cfg['type']}
        return [optimizer], [sched]
Example #4
def train_end_to_end(args, graph_dataset, model):
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=5, cooldown=2, factor=0.5)
    early_stopper = EarlyStopper(args.supervised_model_patience)

    epoch = 0
    start_time = time.time()
    LOG.info('Started training supervised model.')
    while True:
        train_loss = end_to_end_train_epoch(graph_dataset, model, optimizer, args, epoch)
        val_loss, val_acc, val_f1 = end_to_end_val_epoch(graph_dataset, model, args)
        lr_scheduler.step(val_loss)

        epoch_end_time = round(time.time() - start_time, 2)
        LOG.info(f'Epoch {epoch + 1} [{epoch_end_time}s]: '
                 f'Training loss = {train_loss:.4f}, Validation loss = {val_loss:.4f}, '
                 f'Val accuracy = {val_acc:.4f}, Val f1 = {val_f1:.4f}')
        epoch += 1
        if early_stopper.should_stop(model, val_loss) or _STOP_TRAINING:
            break

    # Compute metrics.
    val_y_pred = predict(graph_dataset, model, args, mode='val')
    val_y_pred = raw_output_to_prediction(val_y_pred, args.loss_fn)
    val_y_true = [label for graph in graph_dataset.graphs for label in graph.y[graph.val_mask].cpu().numpy()]
    test_y_pred = predict(graph_dataset, model, args, mode='test')
    test_y_pred = raw_output_to_prediction(test_y_pred, args.loss_fn)
    test_y_true = [label for graph in graph_dataset.graphs for label in graph.y[graph.test_mask].cpu().numpy()]
    log_metrics(graph_dataset, val_y_true, val_y_pred.cpu().numpy(), test_y_true,
                test_y_pred.cpu().numpy(), args.show_detailed_metrics, args.attention_radius, args.max_neighbors)

    # Save trained end-to-end model.
    torch.save(model.state_dict(), os.path.join(args.model_dir, 'node_encoder_end_to_end.pt'))
Example #5
def get_optimizer(parameters, optim, learning_rate, lr_decay, amsgrad, weight_decay, warmup_steps):
    if optim == 'sgd':
        optimizer = SGD(parameters, lr=learning_rate, momentum=0.9, weight_decay=weight_decay, nesterov=True)
    else:
        optimizer = AdamW(parameters, lr=learning_rate, betas=(0.9, 0.999), eps=1e-8, amsgrad=amsgrad, weight_decay=weight_decay)
    init_lr = 1e-7
    scheduler = ExponentialScheduler(optimizer, lr_decay, warmup_steps, init_lr)
    return optimizer, scheduler
Example #6
    def __init__(
            self, save_path, log_path, d_features, d_meta, max_length, d_classifier, n_classes, threshold=None,
            optimizer=None, **kwargs):
        '''
            Arguments:
                save_path -- model file path
                log_path -- log file path
                d_features -- how many PMs
                d_meta -- how many facility types
                max_length -- input sequence length
                d_classifier -- classifier hidden unit
                n_classes -- output dim
                threshold -- if not None, n_classes should be 1 (regression).
            **kwargs: n_layers, n_head, dropout, use_bottleneck, d_bottleneck
        '''

        super().__init__(save_path, log_path)
        self.d_output = n_classes
        self.threshold = threshold
        self.max_length = max_length

        # ----------------------------- Model ------------------------------ #

        self.model = Encoder(TimeFacilityEncoding, d_features=d_features, max_seq_length=max_length,
                                       d_meta=d_meta, **kwargs)

        # --------------------------- Classifier --------------------------- #
        self.classifier = LinearClassifier(d_features * max_length, d_classifier, n_classes)

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)
        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass

        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters, lr=0.001, betas=(0.9, 0.999), weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000, evaluate_every_nstep=100, print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.count_parameters()
        self.set_logger()
        self.set_summary_writer()
Example #7
def train_stage_two(dataset, best_model_file, model_file):
    bestaccuracy = 0.9
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net = ResNet(BasicBlock, [3, 3, 4, 3]).to(device)  # [2,2,2,2]
    net.train()
    for parameter in net.parameters():
        if len(parameter.shape) > 1:
            torch.nn.init.xavier_uniform_(parameter)
    if isfile(best_model_file):
        net.load_state_dict(torch.load(best_model_file))
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    optimizer = AdamW(net.parameters(), lr=0.0001)
    scheduler = CyclicLR(optimizer,
                         0.000001,
                         0.0001,
                         step_size_up=200,
                         mode='triangular2',
                         cycle_momentum=False,
                         last_epoch=-1)
    L1 = torch.nn.L1Loss()
    BCE = torch.nn.BCEWithLogitsLoss()

    for epoch in range(50):
        running_accuracy = []
        for (images, targets) in tqdm(train_loader):
            images, targets = images.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            clsloss = BCE(outputs[:, 0], targets[:, 0])
            regloss = L1(outputs[:, 1:], targets[:, 1:])
            loss = clsloss + regloss
            cls_preds = np.greater(outputs[:, 0].cpu().detach().numpy(), 0)
            cls_truth = targets[:, 0].cpu().detach().numpy()
            correctness = np.equal(cls_preds, cls_truth).astype(int)
            accuracy = sum(correctness) / len(correctness)
            running_accuracy.append(accuracy)
            running_accuracy = running_accuracy[-10:]
            print(' clsloss ' + str(clsloss.cpu().detach().numpy())[:4] +
                  ' regloss ' + str(regloss.cpu().detach().numpy())[:4] +
                  ' accuracy ' + str(np.mean(running_accuracy)),
                  end='\r')
            if np.mean(running_accuracy) > bestaccuracy:
                bestaccuracy = np.mean(running_accuracy)
                torch.save(net.state_dict(), best_model_file)
                # print('totalloss', str(loss.detach().numpy())[:4], 'saved!', end = '\n')
            else:
                pass
                # print('totalloss', str(loss.detach().numpy())[:4]+' ', end = '\n')
            loss.backward()
            optimizer.step()
            scheduler.step(None)
            # if idx%5==0:
            #    print('\n', outputs[0].cpu().detach().numpy(), targets[0].cpu().detach().numpy(), '\n')
            # idx+=1
        torch.save(net.state_dict(), model_file)
        print(epoch)
Example #8
def get_optimizer(learning_rate, parameters, betas, eps, amsgrad, step_decay,
                  weight_decay, warmup_steps, init_lr):
    optimizer = AdamW(parameters,
                      lr=learning_rate,
                      betas=betas,
                      eps=eps,
                      amsgrad=amsgrad,
                      weight_decay=weight_decay)
    scheduler = ExponentialScheduler(optimizer, step_decay, warmup_steps,
                                     init_lr)
    return optimizer, scheduler
Example #9
    def configure_optimizers(self):
        optimizer = AdamW(self.optimizer_params)

        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=int(round(self.config.kv['lr_t0'])),
            T_mult=int(round(self.config.kv['lr_f'])))
        scheduler = WarmupScheduler(optimizer,
                                    epochs=self.config.kv['lr_w'],
                                    next_scheduler=scheduler)

        return [optimizer], [scheduler]
def get_optimisers(G: "nn.Module",
                   args: "tupperware") -> "Union[optim, lr_scheduler]":

    g_optimizer = AdamW(G.parameters(),
                        lr=args.learning_rate,
                        betas=(args.beta_1, args.beta_2))

    g_lr_scheduler = CosineAnnealingWarmRestarts(optimizer=g_optimizer,
                                                 T_0=args.T_0,
                                                 T_mult=args.T_mult,
                                                 eta_min=2e-10)

    return g_optimizer, g_lr_scheduler
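For context, a minimal usage sketch of the optimiser/scheduler pair returned above; the training loop, loader, compute_loss, and args.num_epochs are assumptions, not part of the original code.

# Hypothetical training loop. CosineAnnealingWarmRestarts accepts a fractional
# epoch, so the scheduler can be stepped once per batch as shown here.
g_optimizer, g_lr_scheduler = get_optimisers(G, args)
for epoch in range(args.num_epochs):           # args.num_epochs is assumed
    for step, batch in enumerate(loader):      # loader is assumed
        loss = compute_loss(G, batch)          # placeholder loss computation
        g_optimizer.zero_grad()
        loss.backward()
        g_optimizer.step()
        g_lr_scheduler.step(epoch + step / len(loader))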
def split_optimizer(model: nn.Module, cfg: dict):
    param_weight_decay, param_bias, param_other = split_params(model)
    if cfg['optimizer'] == 'Adam':
        optimizer = Adam(param_other, lr=cfg['lr'])
    elif cfg['optimizer'] == 'SGD':
        optimizer = SGD(param_other, lr=cfg['lr'], momentum=cfg['momentum'])
    elif cfg['optimizer'] == "AdamW":
        optimizer = AdamW(param_other, lr=cfg['lr'])
    else:
        raise NotImplementedError("optimizer {:s} is not support!".format(cfg['optimizer']))
    optimizer.add_param_group(
        {'params': param_weight_decay, 'weight_decay': cfg['weight_decay']})  # add pg1 with weight_decay
    optimizer.add_param_group({'params': param_bias})
    return optimizer
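split_params is not shown in this snippet; below is a minimal sketch, under the assumption that it separates decayed weight matrices from biases and normalisation parameters. It is not the original implementation.

import torch.nn as nn

def split_params(model: nn.Module):
    # Hypothetical helper matching the call above: returns (weights to decay,
    # biases, other parameters such as norm-layer weights) so that weight decay
    # is only applied to weight matrices.
    param_weight_decay, param_bias, param_other = [], [], []
    for module in model.modules():
        if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
            param_bias.append(module.bias)
        if not hasattr(module, 'weight') or not isinstance(module.weight, nn.Parameter):
            continue
        if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)):
            param_other.append(module.weight)
        else:
            param_weight_decay.append(module.weight)
    return param_weight_decay, param_bias, param_other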
Example #12
def train(sqlite_file, model_selection_file, model_file):
    """

    Selects the best model/training parameter combination from a model selection and
    trains a final image classifier on the entire annotated part of the image database.

    SQLITE_FILE: An annotated image database (see create-database).
    MODEL_SELECTION_FILE: Results of a model selection created by the "model-selection" tool.
    MODEL_FILE: Store the finally trained image classification model in this file.

    """

    X, class_to_label, label_to_class = load_ground_truth(sqlite_file)

    X['file'] = X['file'].astype(str)
    y = X['class'].astype(int)

    batch_size, decrease_epochs, decrease_factor, epochs, model_name, num_trained, start_lr = load_model_selection(
        model_selection_file)

    model, device, fit_transform, predict_transform, logits_func = \
        load_pretrained_model(model_name, len(label_to_class), num_trained)

    # optimizer = optim.SGD(model_ft.parameters(), lr=start_lr, momentum=momentum)
    optimizer = AdamW(model.parameters(), lr=start_lr)

    sched = lr_scheduler.StepLR(optimizer,
                                step_size=decrease_epochs,
                                gamma=decrease_factor)

    estimator = ImageClassifier(model=model,
                                model_weights=copy.deepcopy(
                                    model.state_dict()),
                                device=device,
                                criterion=nn.CrossEntropyLoss(),
                                optimizer=optimizer,
                                scheduler=sched,
                                fit_transform=fit_transform,
                                predict_transform=predict_transform,
                                batch_size=batch_size,
                                logits_func=logits_func)

    for _ in range(0, epochs):
        estimator.fit(X, y)

    torch.save(model.state_dict(), model_file)
Example #13
def get_optimizer(opt, learning_rate, parameters, hyper1, hyper2, eps, rebound,
                  lr_decay, decay_rate, milestone, weight_decay, weight_decay_type,
                  warmup_updates, init_lr, last_lr, num_epochs, world_size):
    if opt == 'sgd':
        optimizer = SGD(parameters, lr=learning_rate, momentum=hyper1, weight_decay=weight_decay, nesterov=True)
        opt = 'momentum=%.1f, ' % (hyper1)
        weight_decay_type = 'L2'
    elif opt == 'radamw':
        optimizer = RAdamW(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps, weight_decay=weight_decay)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    elif opt == 'adamw':
        optimizer = AdamW(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps, weight_decay=weight_decay)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    elif opt == 'adabelief':
        optimizer = AdaBelief(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps,
                              weight_decay=weight_decay, weight_decay_type=weight_decay_type)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
    elif opt == 'apollo':
        optimizer = Apollo(parameters, lr=learning_rate, beta=hyper1, eps=eps, rebound=rebound,
                           warmup=warmup_updates, init_lr=init_lr, weight_decay=weight_decay,
                           weight_decay_type=weight_decay_type)
        opt = 'beta=%.1f, eps=%.1e, rebound=%s, ' % (hyper1, eps, rebound)
    elif opt == 'adahessian':
        optimizer = AdaHessian(parameters, lr=learning_rate, betas=(hyper1, hyper2), eps=eps,
                               warmup=warmup_updates, init_lr=init_lr, weight_decay=weight_decay, num_threads=world_size)
        opt = 'betas=(%.1f, %.3f), eps=%.1e, ' % (hyper1, hyper2, eps)
        weight_decay_type = 'decoupled'
    else:
        raise ValueError('unknown optimizer: {}'.format(opt))

    if lr_decay == 'exp':
        opt = opt + 'lr decay={}, decay rate={:.3f}, '.format(lr_decay, decay_rate)
        scheduler = ExponentialLR(optimizer, decay_rate)
    elif lr_decay == 'milestone':
        opt = opt + 'lr decay={} {}, decay rate={:.3f}, '.format(lr_decay, milestone, decay_rate)
        scheduler = MultiStepLR(optimizer, milestones=milestone, gamma=decay_rate)
    elif lr_decay == 'cosine':
        opt = opt + 'lr decay={}, lr_min={}, '.format(lr_decay, last_lr)
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=last_lr)
    else:
        raise ValueError('unknown lr decay: {}'.format(lr_decay))

    opt += 'warmup={}, init_lr={:.1e}, wd={:.1e} ({})'.format(warmup_updates, init_lr, weight_decay, weight_decay_type)
    return optimizer, scheduler, opt
Example #14
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        if self.lr_warmup_steps is None:
            return optimizer

        scheduler = InverseSquareRootLR(optimizer, self.lr_warmup_steps)
        return (
            [optimizer],
            [
                {
                    'scheduler': scheduler,
                    'interval': 'step',
                    'frequency': 1,
                    'reduce_on_plateau': False,
                    'monitor': 'label_smoothed_val_loss' if self.optimize_on_smoothed_loss else 'val_loss',
                }
            ]
        )
Example #15
    def __init__(self,
                 save_path,
                 log_path,
                 d_features,
                 d_out_list,
                 d_classifier,
                 d_output,
                 threshold=None,
                 optimizer=None,
                 **kwargs):
        '''*args: n_layers, n_head, n_channel, n_vchannel, dropout'''
        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #

        self.model = SelfAttentionFeatureSelection(d_features, d_out_list)

        # --------------------------- Classifier --------------------------- #

        self.classifier = LinearClassifier(d_features, d_classifier, d_output)

        # ------------------------------ CUDA ------------------------------ #
        self.data_parallel()

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters,
                                   lr=0.002,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
Example #16
def train_unsupervised(args, graph_dataset, model):
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    lr_scheduler = ReduceLROnPlateau(optimizer, patience=5, cooldown=2, factor=0.5)
    early_stopper = EarlyStopper(args.unsupervised_model_patience)

    start_time = time.time()
    LOG.info("Started training unsupervised model.")
    for epoch in range(args.epochs):
        train_loss, train_count = node_encoder_train_epoch(graph_dataset, model, optimizer, epoch, args)
        val_loss, val_count = node_encoder_val_epoch(graph_dataset, model, epoch, args)
        lr_scheduler.step(val_loss)

        epoch_end_time = round(time.time() - start_time, 2)
        LOG.info(f'Epoch {epoch + 1} [{epoch_end_time}s]: Training loss [over {train_count} nodes] = {train_loss:.4f}, '
                 f'Validation loss [over {val_count} nodes] = {val_loss:.4f}')

        if early_stopper.should_stop(model, val_loss) or _STOP_TRAINING:
            break

    LOG.info("Unsupervised training complete.")
    torch.save(model.state_dict(), os.path.join(args.model_dir, 'node_encoder_unsupervised.pt'))
    def load_parameters(self, filename, ctx=None, strict=True):
        '''
        Load parameters from a file.

        Args:
            filename:       File name.
            ctx:            Context-manager that changes the selected device.
            strict:         Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: `True`.
        '''
        checkpoint = torch.load(filename)
        self.load_state_dict(checkpoint['model_state_dict'], strict=strict)

        try:
            if self.optimizer is None:
                if self.optimizer_f is not None:
                    self.optimizer = self.optimizer_f(
                        self.parameters()
                    )
                else:
                    self.optimizer = AdamW(
                        self.parameters(),
                        lr=self.learning_rate,
                        weight_decay=self.weight_decay
                    )

            self.optimizer.load_state_dict(
                checkpoint['optimizer_state_dict']
            )
        except ValueError as e:
            self.logger.debug(e)
            self.logger.debug("The state of the optimizer in `TransformerModel` was not updated.")

        self.epoch = checkpoint['epoch']
        self.__loss_list = checkpoint['loss'].tolist()
        if ctx is not None:
            self.to(ctx)
            self.__ctx = ctx
Example #18
    def _fit(self, modules: nn.ModuleDict, train_dl: DeviceDataLoader,
             valid_dl: DeviceDataLoader):
        r""" Fits \p modules' learners to the training and validation \p DataLoader objects """
        self._configure_fit_vars(modules)

        for mod_name, module in modules.items():
            lr = config.get_learner_val(mod_name,
                                        LearnerParams.Attribute.LEARNING_RATE)
            wd = config.get_learner_val(mod_name,
                                        LearnerParams.Attribute.WEIGHT_DECAY)
            is_lin_ff = config.DATASET.is_synthetic() and module.module.num_hidden_layers == 0
            if is_lin_ff:
                module.optim = LBFGS(module.parameters(), lr=lr)
            else:
                module.optim = AdamW(module.parameters(),
                                     lr=lr,
                                     weight_decay=wd,
                                     amsgrad=True)
            logging.debug(
                f"{mod_name} Optimizer: {module.optim.__class__.__name__}")

        for ep in range(1, config.NUM_EPOCH + 1):
            # noinspection PyUnresolvedReferences
            for _, module in modules.items():
                module.epoch_start()

            for batch in train_dl:
                for _, module in modules.items():
                    module.process_batch(batch)

            for _, module in modules.items():
                module.calc_valid_loss(valid_dl)
            self._log_epoch(ep, modules)
        self._restore_best_model(modules)
        self.eval()
Example #19
    def build_optimizer(self, args: ClassifierArgs, **kwargs):
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": args.weight_decay,
            },
            {
                "params": [
                    p for n, p in self.model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        return optimizer
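In full training code this grouped-parameter AdamW is typically paired with a warmup schedule, as some of the later examples here do; a rough sketch follows, assuming a classifier object with the method above, a Hugging Face style model output, and made-up step counts.

from transformers import get_linear_schedule_with_warmup

# Hypothetical continuation: warm the learning rate up, then decay it linearly.
optimizer = classifier.build_optimizer(args)         # classifier and args are assumed
num_training_steps = 10000                           # assumed total step count
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),  # assumed 10% warmup
    num_training_steps=num_training_steps)
for batch in train_loader:                           # train_loader is assumed
    loss = classifier.model(**batch).loss            # assumes an HF-style model output
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()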
Example #20
    def __init__(
        self,
        config: TrainConfig,
        model: NSMCModel,
        train_data_loader: DataLoader,
        dev_data_loader: DataLoader,
        test_data_loader: DataLoader,
        logger: Logger,
        summary_writer: SummaryWriter,
    ):
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.model = model
        self.model.to(self.device)

        self.train_data_loader = train_data_loader
        self.dev_data_loader = dev_data_loader
        self.test_data_loader = test_data_loader
        self.logger = logger
        self.summary_writer = summary_writer

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = AdamW(model.parameters(), lr=config.learning_rate)

        # compute the total number of training steps
        self.steps_per_epoch = len(train_data_loader)
        self.total_steps = self.steps_per_epoch * config.num_epochs
        self.warmup_steps = config.warmup_step_ratio * self.total_steps

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.total_steps)
        self.global_step = 0
Example #21
def create_optim(params, args=None, optim_name=None, lr=None):
    if args is None:
        assert optim_name is not None and lr is not None
    else:
        assert optim_name is None and lr is None
        optim_name = args.optim
        lr = args.lr
    if optim_name == 'sgd':
        return SGD(params,
                   lr=lr,
                   momentum=0.9,
                   weight_decay=5e-4,
                   nesterov=True)
    elif optim_name == 'adam':
        return Adam(params, lr=lr, weight_decay=0)
    elif optim_name == 'adamw':
        return AdamW(params, lr=lr, weight_decay=1e-2)
    elif optim_name == 'amsgrad':
        return Adam(params, lr=lr, weight_decay=0, amsgrad=True)
    elif optim_name == 'rmsprop':
        return RMSprop(params, lr=lr, momentum=0.9, weight_decay=0)
    else:
        raise NotImplementedError(
            'Unsupported optimizer_memory: {}'.format(optim_name))
Example #22
    def configure_optimizers(self):
        return AdamW(params=self.model.parameters(),
                     lr=2e-5,
                     eps=1e-6,
                     correct_bias=False)
Example #23
def main():
    def evaluate_accuracy(data_loader, prefix: str, save: bool = False):
        std_transform.eval()
        model.eval()
        pbar = tqdm(data_loader,
                    desc=prefix,
                    leave=True,
                    total=len(data_loader))
        num_corr = 0
        num_tot = 0
        for idx, batch in enumerate(pbar):
            batch = batch.to(device)
            scores = model(zmuv_transform(std_transform(batch.audio_data)),
                           std_transform.compute_lengths(batch.lengths))
            num_tot += scores.size(0)
            labels = torch.tensor([
                l % SETTINGS.training.num_labels
                for l in batch.labels.tolist()
            ]).to(device)
            num_corr += (scores.max(1)[1] == labels).float().sum().item()
            acc = num_corr / num_tot
            pbar.set_postfix(accuracy=f'{acc:.4}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/acc', acc, epoch_idx)
            ws.increment_model(model, acc / 10)

    apb = ArgumentParserBuilder()
    apb.add_options(
        opt('--model', type=str, choices=model_names(), default='las'),
        opt('--workspace',
            type=str,
            default=str(Path('workspaces') / 'default')),
        opt('--load-weights', action='store_true'),
        opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = GoogleSpeechCommandsDatasetLoader()
    sr = SETTINGS.audio.sample_rate
    ds_kwargs = dict(sr=sr, mono=SETTINGS.audio.use_mono)
    train_ds, dev_ds, test_ds = loader.load_splits(
        SETTINGS.dataset.dataset_path, **ds_kwargs)

    sr = SETTINGS.audio.sample_rate
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)
    batchifier = partial(batchify, label_provider=lambda x: x.label)
    truncater = partial(truncate_length,
                        length=int(SETTINGS.training.max_window_size_seconds *
                                   sr))
    train_comp = compose(truncater,
                         TimeshiftTransform().train(),
                         NoiseTransform().train(), batchifier)
    prep_dl = StandardAudioDataLoaderBuilder(train_ds,
                                             collate_fn=batchifier).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(
        train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)
    dev_dl = StandardAudioDataLoaderBuilder(
        dev_ds,
        collate_fn=compose(truncater,
                           batchifier)).build(SETTINGS.training.batch_size)
    test_dl = StandardAudioDataLoaderBuilder(
        test_ds,
        collate_fn=compose(truncater,
                           batchifier)).build(SETTINGS.training.batch_size)

    model = find_model(args.model)().to(device)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params,
                      SETTINGS.training.learning_rate,
                      weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')
    criterion = nn.CrossEntropyLoss()

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path /
                                                      'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(
            dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=True)
    if args.eval:
        ws.load_model(model, best=True)
        evaluate_accuracy(dev_dl, 'Dev')
        evaluate_accuracy(test_dl, 'Test')
        return

    ws.write_args(args)
    ws.write_setting(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs,
                            position=0,
                            leave=True):
        model.train()
        std_transform.train()
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        for batch in pbar:
            batch.to(device)
            audio_data = zmuv_transform(std_transform(batch.audio_data))
            scores = model(audio_data,
                           std_transform.compute_lengths(batch.lengths))
            optimizer.zero_grad()
            model.zero_grad()
            labels = torch.tensor([
                l % SETTINGS.training.num_labels
                for l in batch.labels.tolist()
            ]).to(device)
            loss = criterion(scores, labels)
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            writer.add_scalar('Training/Loss', loss.item(), epoch_idx)

        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay
        evaluate_accuracy(dev_dl, 'Dev', save=True)
    evaluate_accuracy(test_dl, 'Test')
Example #24
def main():
    def evaluate_engine(dataset: WakeWordDataset,
                        prefix: str,
                        save: bool = False,
                        positive_set: bool = False,
                        write_errors: bool = True,
                        mixer: DatasetMixer = None):
        std_transform.eval()

        if use_frame:
            engine = FrameInferenceEngine(int(SETTINGS.training.max_window_size_seconds * 1000),
                                          int(SETTINGS.training.eval_stride_size_seconds * 1000),
                                          SETTINGS.audio.sample_rate,
                                          model,
                                          zmuv_transform,
                                          negative_label=ctx.negative_label,
                                          coloring=ctx.coloring)
        else:
            engine = SequenceInferenceEngine(SETTINGS.audio.sample_rate,
                                             model,
                                             zmuv_transform,
                                             negative_label=ctx.negative_label,
                                             coloring=ctx.coloring,
                                             blank_idx=ctx.blank_label)
        model.eval()
        conf_matrix = ConfusionMatrix()
        pbar = tqdm(dataset, desc=prefix)
        if write_errors:
            with (ws.path / 'errors.tsv').open('a') as f:
                print(prefix, file=f)
        for idx, ex in enumerate(pbar):
            if mixer is not None:
                ex, = mixer([ex])
            audio_data = ex.audio_data.to(device)
            engine.reset()
            seq_present = engine.infer(audio_data)
            if seq_present != positive_set and write_errors:
                with (ws.path / 'errors.tsv').open('a') as f:
                    f.write(f'{ex.metadata.transcription}\t{int(seq_present)}\t{int(positive_set)}\t{ex.metadata.path}\n')
            conf_matrix.increment(seq_present, positive_set)
            pbar.set_postfix(dict(mcc=f'{conf_matrix.mcc}', c=f'{conf_matrix}'))

        logging.info(f'{conf_matrix}')
        if save and not args.eval:
            writer.add_scalar(f'{prefix}/Metric/tp', conf_matrix.tp, epoch_idx)
            ws.increment_model(model, conf_matrix.tp)
        if args.eval:
            threshold = engine.threshold
            with (ws.path / (str(round(threshold, 2)) + '_results.csv') ).open('a') as f:
                f.write(f'{prefix},{threshold},{conf_matrix.tp},{conf_matrix.tn},{conf_matrix.fp},{conf_matrix.fn}\n')

    def do_evaluate():
        evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True)
        evaluate_engine(ww_dev_neg_ds, 'Dev negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_dev_pos_ds, 'Dev noisy positive', positive_set=True, mixer=dev_mixer)
            evaluate_engine(ww_dev_neg_ds, 'Dev noisy negative', positive_set=False, mixer=dev_mixer)
        evaluate_engine(ww_test_pos_ds, 'Test positive', positive_set=True)
        evaluate_engine(ww_test_neg_ds, 'Test negative', positive_set=False)
        if SETTINGS.training.use_noise_dataset:
            evaluate_engine(ww_test_pos_ds, 'Test noisy positive', positive_set=True, mixer=test_mixer)
            evaluate_engine(ww_test_neg_ds, 'Test noisy negative', positive_set=False, mixer=test_mixer)

    apb = ArgumentParserBuilder()
    apb.add_options(opt('--model', type=str, choices=RegisteredModel.registered_names(), default='las'),
                    opt('--workspace', type=str, default=str(Path('workspaces') / 'default')),
                    opt('--load-weights', action='store_true'),
                    opt('--load-last', action='store_true'),
                    opt('--no-dev-per-epoch', action='store_false', dest='dev_per_epoch'),
                    opt('--dataset-paths', '-i', type=str, nargs='+', default=[SETTINGS.dataset.dataset_path]),
                    opt('--eval', action='store_true'))
    args = apb.parser.parse_args()

    use_frame = SETTINGS.training.objective == 'frame'
    ctx = InferenceContext(SETTINGS.training.vocab, token_type=SETTINGS.training.token_type, use_blank=not use_frame)
    if use_frame:
        batchifier = WakeWordFrameBatchifier(ctx.negative_label,
                                             window_size_ms=int(SETTINGS.training.max_window_size_seconds * 1000))
        criterion = nn.CrossEntropyLoss()
    else:
        tokenizer = WakeWordTokenizer(ctx.vocab, ignore_oov=False)
        batchifier = AudioSequenceBatchifier(ctx.negative_label, tokenizer)
        criterion = nn.CTCLoss(ctx.blank_label)

    ws = Workspace(Path(args.workspace), delete_existing=not args.eval)
    writer = ws.summary_writer
    set_seed(SETTINGS.training.seed)
    loader = WakeWordDatasetLoader()
    ds_kwargs = dict(sr=SETTINGS.audio.sample_rate, mono=SETTINGS.audio.use_mono, frame_labeler=ctx.labeler)

    ww_train_ds, ww_dev_ds, ww_test_ds = WakeWordDataset(metadata_list=[], set_type=DatasetType.TRAINING, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.DEV, **ds_kwargs), \
                                         WakeWordDataset(metadata_list=[], set_type=DatasetType.TEST, **ds_kwargs)
    for ds_path in args.dataset_paths:
        ds_path = Path(ds_path)
        train_ds, dev_ds, test_ds = loader.load_splits(ds_path, **ds_kwargs)
        ww_train_ds.extend(train_ds)
        ww_dev_ds.extend(dev_ds)
        ww_test_ds.extend(test_ds)
    print_stats(f'Wake word dataset', ww_train_ds, ww_dev_ds, ww_test_ds)

    ww_dev_pos_ds = ww_dev_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_dev_neg_ds = ww_dev_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)
    ww_test_pos_ds = ww_test_ds.filter(lambda x: ctx.searcher.search(x.transcription), clone=True)
    ww_test_neg_ds = ww_test_ds.filter(lambda x: not ctx.searcher.search(x.transcription), clone=True)

    print_stats(f'Dev dataset', ww_dev_pos_ds, ww_dev_neg_ds)
    print_stats(f'Test dataset', ww_test_pos_ds, ww_test_neg_ds)
    device = torch.device(SETTINGS.training.device)
    std_transform = StandardAudioTransform().to(device).eval()
    zmuv_transform = ZmuvTransform().to(device)

    train_comp = (NoiseTransform().train(), batchifier)

    if SETTINGS.training.use_noise_dataset:
        noise_ds = RecursiveNoiseDatasetLoader().load(Path(SETTINGS.raw_dataset.noise_dataset_path),
                                                      sr=SETTINGS.audio.sample_rate,
                                                      mono=SETTINGS.audio.use_mono)
        logging.info(f'Loaded {len(noise_ds.metadata_list)} noise files.')
        noise_ds_train, noise_ds_dev = noise_ds.split(Sha256Splitter(80))
        noise_ds_dev, noise_ds_test = noise_ds_dev.split(Sha256Splitter(50))
        train_comp = (DatasetMixer(noise_ds_train).train(),) + train_comp
        dev_mixer = DatasetMixer(noise_ds_dev, seed=0, do_replace=False)
        test_mixer = DatasetMixer(noise_ds_test, seed=0, do_replace=False)
    train_comp = compose(*train_comp)

    prep_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=batchify).build(1)
    prep_dl.shuffle = True
    train_dl = StandardAudioDataLoaderBuilder(ww_train_ds, collate_fn=train_comp).build(SETTINGS.training.batch_size)

    model = RegisteredModel.find_registered_class(args.model)(ctx.num_labels).to(device).streaming()
    if SETTINGS.training.convert_static:
        model = ConvertedStaticModel(model, 40, 10)
    params = list(filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = AdamW(params, SETTINGS.training.learning_rate, weight_decay=SETTINGS.training.weight_decay)
    logging.info(f'{sum(p.numel() for p in params)} parameters')

    if (ws.path / 'zmuv.pt.bin').exists():
        zmuv_transform.load_state_dict(torch.load(str(ws.path / 'zmuv.pt.bin')))
    else:
        for idx, batch in enumerate(tqdm(prep_dl, desc='Constructing ZMUV')):
            batch.to(device)
            zmuv_transform.update(std_transform(batch.audio_data))
            if idx == 2000:  # TODO: quick debugging, remove later
                break
        logging.info(dict(zmuv_mean=zmuv_transform.mean, zmuv_std=zmuv_transform.std))
    torch.save(zmuv_transform.state_dict(), str(ws.path / 'zmuv.pt.bin'))

    if args.load_weights:
        ws.load_model(model, best=not args.load_last)
    if args.eval:
        ws.load_model(model, best=not args.load_last)
        do_evaluate()
        return

    ws.write_args(args)
    ws.write_settings(SETTINGS)
    writer.add_scalar('Meta/Parameters', sum(p.numel() for p in params))
    for epoch_idx in trange(SETTINGS.training.num_epochs, position=0, leave=True):
        model.train()
        std_transform.train()
        model.streaming_state = None
        pbar = tqdm(train_dl,
                    total=len(train_dl),
                    position=1,
                    desc='Training',
                    leave=True)
        total_loss = torch.Tensor([0.0]).to(device)
        for batch in pbar:
            batch.to(device)
            if use_frame:
                scores = model(zmuv_transform(std_transform(batch.audio_data)),
                               std_transform.compute_lengths(batch.lengths))
                loss = criterion(scores, batch.labels)
            else:
                lengths = std_transform.compute_lengths(batch.audio_lengths)
                scores = model(zmuv_transform(std_transform(batch.audio_data)), lengths)
                scores = F.log_softmax(scores, -1)  # [num_frames x batch_size x num_labels]
                lengths = torch.tensor([model.compute_length(x.item()) for x in lengths]).to(device)
                loss = criterion(scores, batch.labels, lengths, batch.label_lengths)
            optimizer.zero_grad()
            model.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(dict(loss=f'{loss.item():.3}'))
            with torch.no_grad():
                total_loss += loss

        for group in optimizer.param_groups:
            group['lr'] *= SETTINGS.training.lr_decay

        mean = total_loss / len(train_dl)
        writer.add_scalar('Training/Loss', mean.item(), epoch_idx)
        writer.add_scalar('Training/LearningRate', group['lr'], epoch_idx)

        if args.dev_per_epoch:
            evaluate_engine(ww_dev_pos_ds, 'Dev positive', positive_set=True, save=True, write_errors=False)

    do_evaluate()
Example #25
    if args.encoder == 'bert':
        tokenizer = get_bert_tokenizer(
            args.pretrain_checkpoint,
            add_tokens=['[EOT]']
        )
        two_tower_model = build_biencoder_model(
            model_tokenizer=tokenizer,
            args=args, aggregation='cls'
        )

    else:
        raise Exception('unsupported encoder: {}'.format(args.encoder))

    optimizer = AdamW(
        [{'params': two_tower_model.parameters(), 'initial_lr': args.lr}],
        lr=args.lr
    )
    if args.best_acc > 0:
        two_tower_model, likelihood_criterion = get_saved_model_and_optimizer(
            model=two_tower_model,
            optimizer=optimizer,
            checkpoint_path=os.path.join(
                args.model_checkpoint,
                "%s_acc_%.5f" % (args.task_name, args.best_acc)
            ))
        # two_tower_model.load_state_dict(torch.load(os.path.join(
        #     args.model_checkpoint, "%s_acc_%.5f" % (args.task_name,
        #                                             args.best_acc)
        # )))

    if args.distributed and args.device == 'cuda':
Example #26
    def __init__(self,
                 save_path,
                 log_path,
                 n_depth,
                 d_features,
                 d_classifier,
                 d_output,
                 threshold=None,
                 stack='ShuffleSelfAttention',
                 expansion_layer='ChannelWiseConvExpansion',
                 mode='1d',
                 optimizer=None,
                 **kwargs):
        '''
            Arguments:
                mode:   1d:         1d output
                        2d:         2d output
                        residual:   residual output
                        dense:      dense net
            *args: n_layers, n_head, n_channel, n_vchannel, dropout, use_bottleneck, d_bottleneck
        '''

        super().__init__(save_path, log_path)
        self.d_output = d_output
        self.threshold = threshold

        # ----------------------------- Model ------------------------------ #
        stack_dict = {
            'ReactionAttention': ReactionAttentionStack,
            'SelfAttention': SelfAttentionStack,
            'Alternate': AlternateStack,
            'Parallel': ParallelStack,
            'ShuffleSelfAttention': ShuffleSelfAttentionStack,
            'ShuffleSelfAttentionStackV2': ShuffleSelfAttentionStackV2,
        }
        expansion_dict = {
            'LinearExpansion': LinearExpansion,
            'ReduceParamLinearExpansion': ReduceParamLinearExpansion,
            'ConvExpansion': ConvExpansion,
            'LinearConvExpansion': LinearConvExpansion,
            'ShuffleConvExpansion': ShuffleConvExpansion,
            'ChannelWiseConvExpansion': ChannelWiseConvExpansion,
        }

        self.model = stack_dict[stack](expansion_dict[expansion_layer],
                                       n_depth=n_depth,
                                       d_features=d_features,
                                       mode=mode,
                                       **kwargs)

        # --------------------------- Classifier --------------------------- #
        if mode == '1d':
            self.classifier = LinearClassifier(d_features, d_classifier,
                                               d_output)
        elif mode == '2d':
            self.classifier = LinearClassifier(n_depth * d_features,
                                               d_classifier, d_output)
        else:
            self.classifier = None

        # ------------------------------ CUDA ------------------------------ #
        # If GPU available, move the graph to GPU(s)

        self.CUDA_AVAILABLE = self.check_cuda()
        if self.CUDA_AVAILABLE:
            # self.model.cuda()
            # self.classifier.cuda()

            device_ids = list(range(torch.cuda.device_count()))
            self.model = nn.DataParallel(self.model, device_ids)
            self.classifier = nn.DataParallel(self.classifier, device_ids)
            self.model.to('cuda')
            self.classifier.to('cuda')
            assert (next(self.model.parameters()).is_cuda)
            assert (next(self.classifier.parameters()).is_cuda)
            pass

        else:
            print('CUDA not found or not enabled, use CPU instead')

        # ---------------------------- Optimizer --------------------------- #
        self.parameters = list(self.model.parameters()) + list(
            self.classifier.parameters())
        if optimizer is None:
            self.optimizer = AdamW(self.parameters,
                                   lr=0.002,
                                   betas=(0.9, 0.999),
                                   weight_decay=0.001)

        # ------------------------ training control ------------------------ #
        self.controller = TrainingControl(max_step=100000,
                                          evaluate_every_nstep=100,
                                          print_every_nstep=10)
        self.early_stopping = EarlyStopping(patience=50)

        # --------------------- logging and tensorboard -------------------- #
        self.set_logger()
        self.set_summary_writer()
Example #27
def train(model_path: Optional[str] = None,
          data_path: Optional[str] = None,
          version: Optional[str] = None,
          save_to_dir: Optional[str] = None):

    train_dl, val_dl, test_dl, nbatches, parser, version = init(
        data_path, version=version, save_to_dir=save_to_dir)
    core_opt = AdamW(parser.parameters(),
                     lr=0.,
                     betas=(0.9, 0.98),
                     eps=1e-12,
                     weight_decay=1e-02)
    dec_schedule = make_cosine_schedule(max_lr=5e-04,
                                        warmup_steps=nbatches // 4,
                                        decay_over=decoder_epochs * nbatches)
    mutual_schedule = make_cosine_schedule_with_linear_restarts(
        max_lr=1e-04,
        warmup_steps=nbatches // 4,
        decay_over=mutual_epochs * nbatches,
        triangle_decay=40 * nbatches)
    st_loss = NormCrossEntropy(sep_id=parser.atom_tokenizer.sep_token_id,
                               ignore_index=parser.atom_tokenizer.pad_token_id)
    sh_loss = SinkhornLoss()

    if model_path is not None:
        print('Loading checkpoint...')
        step_num, opt_dict, init_epoch = load_model(parser, model_path)
        if init_epoch < decoder_epochs:
            schedule = dec_schedule
        else:
            schedule = mutual_schedule

        opt = Scheduler(core_opt, schedule)
        opt.step_num = step_num
        opt.lr = opt.schedule(opt.step_num)
        opt.opt.load_state_dict(opt_dict)
        del opt_dict
    else:
        opt = Scheduler(core_opt, dec_schedule)
        init_epoch = 0

    if save_to_dir is None:
        save_to_dir = './stored_models'

    for e in range(init_epoch, decoder_epochs + mutual_epochs):
        validate = e % 5 == 0 and e != init_epoch
        save = e % 5 == 0 and e != init_epoch
        linking_weight = 0.33

        if save:
            print('\tSaving')
            torch.save(
                {
                    'model_state_dict': parser.state_dict(),
                    'opt_state_dict': opt.opt.state_dict(),
                    'step': opt.step_num,
                    'epoch': e
                }, f'{save_to_dir}/{version}/{e}.model')

        if e < decoder_epochs:
            with open(f'{save_to_dir}/{version}/log.txt', 'a') as stream:
                logprint('=' * 64, [stream])
                logprint(f'Pre-epoch {e}', [stream])
                logprint(' ' * 50 + f'LR: {opt.lr}\t({opt.step_num})',
                         [stream])
                supertagging_loss, linking_loss = parser.pretrain_decoder_epoch(
                    train_dl, st_loss, sh_loss, opt, linking_weight)
                logprint(f' Supertagging Loss:\t\t{supertagging_loss:5.2f}',
                         [stream])
                logprint(f' Linking Loss:\t\t\t{linking_loss:5.2f}', [stream])
                if validate:
                    with open(f'{save_to_dir}/{version}/val_log.txt',
                              'a') as valstream:
                        logprint(f'Epoch {e}', [valstream])
                        logprint('-' * 64, [stream, valstream])
                        sentence_ac, atom_ac, link_ac = parser.preval_epoch(
                            val_dl)
                        logprint(
                            f' Sentence Accuracy:\t\t{(sentence_ac * 100):6.2f}',
                            [stream, valstream])
                        logprint(f' Atom Accuracy:\t\t{(atom_ac * 100):6.2f}',
                                 [stream, valstream])
                        logprint(f' Link Accuracy:\t\t{(link_ac * 100):6.2f}',
                                 [stream, valstream])
                continue
        elif e == decoder_epochs:
            opt = Scheduler(opt.opt, mutual_schedule)

        with open(f'{save_to_dir}/{version}/log.txt', 'a') as stream:
            logprint('=' * 64, [stream])
            logprint(f'Epoch {e}', [stream])
            logprint(' ' * 50 + f'LR: {opt.lr}\t({opt.step_num})', [stream])
            logprint(' ' * 50 + f'LW: {linking_weight}', [stream])
            logprint('-' * 64, [stream])
            supertagging_loss, linking_loss = parser.train_epoch(
                train_dl, st_loss, sh_loss, opt, linking_weight)
            logprint(f' Supertagging Loss:\t\t{supertagging_loss:5.2f}',
                     [stream])
            logprint(f' Linking Loss:\t\t\t{linking_loss:5.2f}', [stream])
            if validate:
                with open(f'{save_to_dir}/{version}/val_log.txt',
                          'a') as valstream:
                    logprint(f'Epoch {e}', [valstream])
                    logprint('-' * 64, [stream, valstream])
                    sentence_ac, atom_ac, link_ac = parser.eval_epoch(
                        val_dl, link=True)
                    logprint(
                        f' Sentence Accuracy:\t\t{(sentence_ac * 100):6.2f}',
                        [stream, valstream])
                    logprint(f' Atom Accuracy:\t\t{(atom_ac * 100):6.2f}',
                             [stream, valstream])
                    logprint(f' Link Accuracy:\t\t{(link_ac * 100):6.2f}',
                             [stream, valstream])
            logprint('\n', [stream])
Example #28
        sampler=data.RandomSampler(train_dataset),  # =sampler,
        num_workers=4,
        collate_fn=pad)
    dev_iter = data.DataLoader(dataset=dev_dataset,
                               batch_size=hp.batch_size,
                               shuffle=False,
                               num_workers=4,
                               collate_fn=pad)
    test_iter = data.DataLoader(dataset=test_dataset,
                                batch_size=hp.batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=pad)

    # optimizer = BertAdam(model.parameters(), lr=hp.lr)
    optimizer = AdamW(model.parameters(), lr=hp.lr, weight_decay=0.01)
    total_train_step = hp.n_epochs * len(train_iter)
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_train_step,
    )

    criterion = nn.CrossEntropyLoss(ignore_index=0)
    mode = hp.model
    savedir = "mpqa_eval_" + mode
    os.makedirs(savedir, exist_ok=True)

    writer = SummaryWriter()
    highest = 0
    for epoch in range(1, hp.n_epochs + 1):
Example #29
def train_stage_three(dataset, best_model_file, model_file):
    bestaccuracy = 0.9
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net = MyUNet(3, device).to(device)
    net.train()
    for parameter in net.parameters():
        if len(parameter.shape) > 1:
            torch.nn.init.xavier_uniform_(parameter)
    if isfile(best_model_file):
        net.load_state_dict(torch.load(best_model_file))
    train_loader = DataLoader(dataset, batch_size=4, shuffle=True)
    optimizer = AdamW(net.parameters(), lr=0.0001)
    scheduler = CyclicLR(optimizer,
                         0.000001,
                         0.0001,
                         step_size_up=200,
                         mode='triangular2',
                         cycle_momentum=False,
                         last_epoch=-1)
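    # CyclicLR with mode='triangular2' oscillates the learning rate between the base
    # value (1e-6) and the peak (1e-4), taking step_size_up=200 optimizer steps to climb
    # and (by default) another 200 to descend, and halves the peak amplitude after each
    # full cycle.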
    L1 = torch.nn.L1Loss(reduction='sum')  # sum over elements (equivalent to the deprecated size_average=False)

    for epoch in range(50):
        for (images, targets, out_masks) in tqdm(train_loader):
            images = images.to(device)
            targets = targets.to(device)
            out_masks = out_masks.to(device)
            optimizer.zero_grad()
            outputs = net(images)
            loss = L1(outputs * out_masks, targets * out_masks) / 4
            outputs = (outputs * out_masks).cpu().detach().numpy()
            targets = (targets * out_masks).cpu().detach().numpy()
            if np.mean(np.linalg.norm(targets * 100, axis=1)) > 0:
                truth_norm = np.linalg.norm(targets * 100, axis=1).flatten()
                error_norm = np.linalg.norm(outputs * 100 - targets * 100,
                                            axis=1).flatten()
                # Filter both arrays with the same mask so entries stay aligned
                mask = truth_norm > 0
                truth_norm, error_norm = truth_norm[mask], error_norm[mask]
                accuracy = sum(
                    (error_norm / truth_norm) < 0.1) / len(error_norm)
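                # i.e. accuracy is the fraction of remaining entries whose L2 error is
                # within 10% of the corresponding ground-truth norm.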
                print('mean error',
                      np.mean(error_norm / truth_norm),
                      'accuracy',
                      accuracy,
                      end='\t')
            else:
                accuracy = 0.0
            print('L1loss', loss.cpu().detach().numpy(), end='\r')
            if accuracy > bestaccuracy:
                bestaccuracy = accuracy
                torch.save(net.state_dict(), best_model_file)
            else:
                pass
                # print('totalloss', str(loss.detach().numpy())[:4]+' ', end = '\n')
            loss.backward()
            optimizer.step()
            scheduler.step()
            # if idx%5==0:
            #    print('\n', outputs[0].cpu().detach().numpy(), targets[0].cpu().detach().numpy(), '\n')
            # idx+=1
        torch.save(net.state_dict(), model_file)
        print(epoch)
Example #30
def train():
    global writer
    # For parsing commandline arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help='path to dataset folder containing train-test-validation folders')
    parser.add_argument("--checkpoint_dir",
                        type=str,
                        required=True,
                        help='path to folder for saving checkpoints')
    parser.add_argument("--checkpoint",
                        type=str,
                        help='path of checkpoint for pretrained model')
    parser.add_argument(
        "--train_continue",
        type=bool,
        default=False,
        help=
        'If resuming from checkpoint, set to True and set `checkpoint` path. Default: False.'
    )
    parser.add_argument("--epochs",
                        type=int,
                        default=200,
                        help='number of epochs to train. Default: 200.')
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=3,
                        help='batch size for training. Default: 3.')
    parser.add_argument("--validation_batch_size",
                        type=int,
                        default=6,
                        help='batch size for validation. Default: 6.')
    parser.add_argument("--init_learning_rate",
                        type=float,
                        default=0.0001,
                        help='set initial learning rate. Default: 0.0001.')
    parser.add_argument(
        "--milestones",
        type=list,
        default=[25, 50],
        help=
        'UNUSED NOW: Set to epoch values where you want to decrease learning rate by a factor of 0.1. Default: [25, 50]'
    )
    parser.add_argument(
        "--progress_iter",
        type=int,
        default=200,
        help=
        'frequency of reporting progress and validation. N: after every N iterations. Default: 200.'
    )
    parser.add_argument(
        "--checkpoint_epoch",
        type=int,
        default=5,
        help=
        'checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly of size 151 MB. Default: 5.'
    )
    args = parser.parse_args()

    ##[TensorboardX](https://github.com/lanpa/tensorboardX)
    ### For visualizing loss and interpolated frames

    ###Initialize flow computation and arbitrary-time flow interpolation CNNs.

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(device)
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)

    ### Initialize backward warpers for train and validation datasets

    train_W_dim = 352
    train_H_dim = 352

    trainFlowBackWarp = model.backWarp(train_W_dim, train_H_dim, device)
    trainFlowBackWarp = trainFlowBackWarp.to(device)
    validationFlowBackWarp = model.backWarp(train_W_dim * 2, train_H_dim,
                                            device)
    validationFlowBackWarp = validationFlowBackWarp.to(device)

    ###Load Datasets

    # Channel wise mean calculated on custom training dataset
    # mean = [0.43702903766008444, 0.43715053433990597, 0.40436416782660994]
    mean = [0.5] * 3
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train',
                                     randomCropSize=(train_W_dim, train_H_dim),
                                     transform=transform,
                                     train=True)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=args.train_batch_size,
                                              shuffle=True,
                                              num_workers=2,
                                              pin_memory=True)

    validationset = dataloader.SuperSloMo(
        root=args.dataset_root + '/validation',
        transform=transform,
        randomCropSize=(2 * train_W_dim, train_H_dim),
        train=False)
    validationloader = torch.utils.data.DataLoader(
        validationset,
        batch_size=args.validation_batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True)

    print(trainset, validationset)

    ###Create transform to display image from tensor

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)
    TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    ###Utils

    def get_lr(optimizer):
        for param_group in optimizer.param_groups:
            return param_group['lr']

    ###Loss and Optimizer

    L1_lossFn = nn.L1Loss()
    MSE_LossFn = nn.MSELoss()

    if args.train_continue:
        dict1 = torch.load(args.checkpoint)
        last_epoch = dict1['epoch'] * len(trainloader)
    else:
        last_epoch = -1

    params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())

    optimizer = AdamW(params, lr=args.init_learning_rate, amsgrad=True)
    # optimizer = optim.SGD(params, lr=args.init_learning_rate, momentum=0.9, nesterov=True)

    # Scheduler to decrease the learning rate by a factor of 10 when the loss plateaus.
    # Suggested patience value:
    #   patience = (number of items in train dataset / train_batch_size) * (number of epochs of patience)
    # The scheduler documentation speaks of epochs, but here it is stepped on iterations,
    # so an "epoch" of patience works out to the formula above (roughly, when using an
    # approximate dataset count).
    # If the model seems to plateau quickly, reduce the number of patience epochs accordingly.
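    # Worked example of that formula: the values used below, patience = len(trainloader) * 3
    # and cooldown = len(trainloader) * 2, correspond to roughly 3 epochs of patience and
    # 2 epochs of cooldown. With, say, 10,000 training samples (a purely hypothetical count)
    # and train_batch_size=3, len(trainloader) is about 3,334 iterations per epoch, so
    # patience comes to about 10,002 scheduler steps.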

    # scheduler = optim.lr_scheduler.CyclicLR(optimizer,
    #                                         base_lr=1e-8,
    #                                         max_lr=9.0e-3,
    #                                         step_size_up=3500,
    #                                         mode='triangular2',
    #                                         cycle_momentum=False,
    #                                         last_epoch=last_epoch)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=len(trainloader) * 3,
        cooldown=len(trainloader) * 2,
        verbose=True,
        min_lr=1e-8)

    # ReduceLROnPlateau was chosen here to make training more adaptive: this model tends
    # to converge or plateau faster, with more rapid swings over time, so reacting to
    # stagnation as it occurs seems more useful than decaying the learning rate at fixed milestones.

    ###Initializing VGG16 model for perceptual loss

    vgg16 = torchvision.models.vgg16(pretrained=True)
    vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])
    vgg16_conv_4_3.to(device)

    for param in vgg16_conv_4_3.parameters():
        param.requires_grad = False

    # Validation function

    def validate():
        # For details see training.
        psnr = 0
        tloss = 0
        flag = 1
        with torch.no_grad():
            for validationIndex, (validationData,
                                  validationFrameIndex) in enumerate(
                                      validationloader, 0):
                frame0, frameT, frame1 = validationData

                I0 = frame0.to(device)
                I1 = frame1.to(device)
                IFrame = frameT.to(device)

                torch.cuda.empty_cache()
                flowOut = flowComp(torch.cat((I0, I1), dim=1))
                F_0_1 = flowOut[:, :2, :, :]
                F_1_0 = flowOut[:, 2:, :, :]

                fCoeff = model.getFlowCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)
                torch.cuda.empty_cache()
                intrpOut = ArbTimeFlowIntrp(
                    torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                               g_I0_F_t_0),
                              dim=1))

                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0
                # torch.cuda.empty_cache()
                g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)

                wCoeff = model.getWarpCoeff(validationFrameIndex, device)
                torch.cuda.empty_cache()
                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                        g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # For tensorboard
                if (flag):
                    retImg = torchvision.utils.make_grid([
                        revNormalize(frame0[0]),
                        revNormalize(frameT[0]),
                        revNormalize(Ft_p.cpu()[0]),
                        revNormalize(frame1[0])
                    ],
                                                         padding=10)
                    flag = 0

                # loss
                recnLoss = L1_lossFn(Ft_p, IFrame)
                # torch.cuda.empty_cache()
                prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p),
                                      vgg16_conv_4_3(IFrame))

                warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                    g_I1_F_t_1, IFrame) + L1_lossFn(
                        validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                            validationFlowBackWarp(I1, F_0_1), I0)
                torch.cuda.empty_cache()
                loss_smooth_1_0 = torch.mean(
                    torch.abs(F_1_0[:, :, :, :-1] -
                              F_1_0[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_1_0[:, :, :-1, :] -
                                            F_1_0[:, :, 1:, :]))
                loss_smooth_0_1 = torch.mean(
                    torch.abs(F_0_1[:, :, :, :-1] -
                              F_0_1[:, :, :, 1:])) + torch.mean(
                                  torch.abs(F_0_1[:, :, :-1, :] -
                                            F_0_1[:, :, 1:, :]))
                loss_smooth = loss_smooth_1_0 + loss_smooth_0_1

                # torch.cuda.empty_cache()
                loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth
                tloss += loss.item()

                # psnr
                MSE_val = MSE_LossFn(Ft_p, IFrame)
                psnr += (10 * log10(1 / MSE_val.item()))
                torch.cuda.empty_cache()

        return (psnr / len(validationloader)), (tloss /
                                                len(validationloader)), retImg

    ### Initialization

    if args.train_continue:
        ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
        flowComp.load_state_dict(dict1['state_dictFC'])

        optimizer.load_state_dict(dict1.get('state_optimizer', {}))
        scheduler.load_state_dict(dict1.get('state_scheduler', {}))

        for param_group in optimizer.param_groups:
            param_group['lr'] = dict1.get('learningRate',
                                          args.init_learning_rate)

    else:
        dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    ### Training

    import time

    start = time.time()
    cLoss = dict1['loss']
    valLoss = dict1['valLoss']
    valPSNR = dict1['valPSNR']
    checkpoint_counter = 0

    ### Main training loop

    optimizer.step()
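    # Note: no gradients have been computed yet, so the optimizer.step() above is
    # effectively a no-op for AdamW (parameters without gradients are skipped); it is
    # presumably here to avoid PyTorch's warning about stepping a scheduler before the
    # optimizer when resuming training.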

    for epoch in range(dict1['epoch'] + 1, args.epochs):
        print("Epoch: ", epoch)

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        iLoss = 0

        for trainIndex, (trainData,
                         trainFrameIndex) in enumerate(trainloader, 0):

            ## Getting the input and the target from the training set
            frame0, frameT, frame1 = trainData

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            IFrame = frameT.to(device)
            optimizer.zero_grad()
            # torch.cuda.empty_cache()
            # Calculate flow between reference frames I0 and I1
            flowOut = flowComp(torch.cat((I0, I1), dim=1))

            # Extracting flows between I0 and I1 - F_0_1 and F_1_0
            F_0_1 = flowOut[:, :2, :, :]
            F_1_0 = flowOut[:, 2:, :, :]

            fCoeff = model.getFlowCoeff(trainFrameIndex, device)

            # Calculate intermediate flows
            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
            g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)
            torch.cuda.empty_cache()
            # Calculate optical flow residuals and visibility maps
            intrpOut = ArbTimeFlowIntrp(
                torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1,
                           g_I0_F_t_0),
                          dim=1))

            # Extract optical flow residuals and visibility maps
            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
            V_t_1 = 1 - V_t_0
            # torch.cuda.empty_cache()
            # Get intermediate frames from the intermediate flows
            g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)
            g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)
            # torch.cuda.empty_cache()
            wCoeff = model.getWarpCoeff(trainFrameIndex, device)
            torch.cuda.empty_cache()
            # Calculate final intermediate frame
            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 *
                    g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

            # Loss
            recnLoss = L1_lossFn(Ft_p, IFrame)
            # torch.cuda.empty_cache()

            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))
            # torch.cuda.empty_cache()
            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(
                g_I1_F_t_1, IFrame) + L1_lossFn(
                    trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(
                        trainFlowBackWarp(I1, F_0_1), I0)

            loss_smooth_1_0 = torch.mean(
                torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))
            loss_smooth_0_1 = torch.mean(
                torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])
            ) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))
            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1
            # torch.cuda.empty_cache()
            # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4
            # since the loss in paper is calculated for input pixels in range 0-255
            # and the input to our network is in range 0-1
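            # Sanity check on the scaling: 0.8 * 255 = 204 and 0.4 * 255 = 102, matching
            # the reconstruction and warping coefficients below.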
            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth

            # Backpropagate

            loss.backward()
            optimizer.step()
            scheduler.step(loss.item())

            iLoss += loss.item()
            torch.cuda.empty_cache()
            # Validation and progress every `args.progress_iter` iterations
            if ((trainIndex % args.progress_iter) == args.progress_iter - 1):
                # Step the plateau scheduler on the loss averaged over this progress window
                scheduler.step(iLoss / args.progress_iter)

                end = time.time()

                psnr, vLoss, valImg = validate()
                optimizer.zero_grad()
                # torch.cuda.empty_cache()
                valPSNR[epoch].append(psnr)
                valLoss[epoch].append(vLoss)

                # Tensorboard
                itr = trainIndex + epoch * (len(trainloader))

                writer.add_scalars(
                    'Loss', {
                        'trainLoss': iLoss / args.progress_iter,
                        'validationLoss': vLoss
                    }, itr)
                writer.add_scalar('PSNR', psnr, itr)

                writer.add_image('Validation', valImg, itr)
                #####

                endVal = time.time()

                print(
                    " Loss: %0.6f  Iterations: %4d/%4d  TrainExecTime: %0.1f  ValLoss:%0.6f  ValPSNR: %0.4f  ValEvalTime: %0.2f LearningRate: %.1e"
                    % (iLoss / args.progress_iter, trainIndex,
                       len(trainloader), end - start, vLoss, psnr,
                       endVal - end, get_lr(optimizer)))

                # torch.cuda.empty_cache()
                cLoss[epoch].append(iLoss / args.progress_iter)
                iLoss = 0
                start = time.time()

        # Create checkpoint after every `args.checkpoint_epoch` epochs
        if (epoch % args.checkpoint_epoch) == args.checkpoint_epoch - 1:
            dict1 = {
                'Detail': "End to end Super SloMo.",
                'epoch': epoch,
                'timestamp': datetime.datetime.now(),
                'trainBatchSz': args.train_batch_size,
                'validationBatchSz': args.validation_batch_size,
                'learningRate': get_lr(optimizer),
                'loss': cLoss,
                'valLoss': valLoss,
                'valPSNR': valPSNR,
                'state_dictFC': flowComp.state_dict(),
                'state_dictAT': ArbTimeFlowIntrp.state_dict(),
                'state_optimizer': optimizer.state_dict(),
                'state_scheduler': scheduler.state_dict()
            }
            torch.save(
                dict1, args.checkpoint_dir + "/SuperSloMo" +
                str(checkpoint_counter) + ".ckpt")
            checkpoint_counter += 1