Example #1
    def train(self,
              n_iters,
              lr,
              lamb,
              batch_size=None,
              alpha=None,
              test=False,
              print_interval=100,
              store_params=False,
              save_full_loss=False):
        """train the model for a number of iterations

        :n_iter: integer, number of iterations to run
        :param lr: learning rate, constant or function of iteration
        :param lamb: regularization parameter, constant
        :param batch_size:
            constant or function of iteration, use full dataset if not given
        :param alpha: function of iteration, ignored if not given
        """
        proximal_op = proximal.SoftThresholding(lamb)
        if alpha is None:
            optimizer = fista.ForwardBackward(
                self.model.parameters(),
                1,
                proximal_op,
                regularize_idxs=self.regularize_idxs)
        else:
            optimizer = fista.FISTA(self.model.parameters(),
                                    1,
                                    proximal_op,
                                    regularize_idxs=self.regularize_idxs)
        if isinstance(lr, numbers.Number):
            decay = lambda _: lr
        else:
            decay = lr
        # LambdaLR updates the learning rate of the wrapped torch.optim optimizer
        scheduler = LambdaLR(optimizer.optimizer, lr_lambda=decay)
        # resume the schedule from the current iteration counter
        scheduler.last_epoch = self.counter
        for _ in range(n_iters):
            self._train_step(optimizer, scheduler, batch_size, alpha)
            self.l1_losses.append(self.l1_loss() * lamb)
            if test:
                self.test(update=True)
            if store_params:
                self.params_his.append(deepcopy(list(self.model.parameters())))
            if self.counter % print_interval == 0:
                self.log(test)
            if save_full_loss:
                outputs = self.model(self.data)
                full_loss = F.cross_entropy(outputs, self.target).item()
                full_loss += self.l1_losses[-1]
                self.full_losses.append(full_loss)
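
Both examples drive the learning rate through torch.optim.lr_scheduler.LambdaLR. As a minimal, self-contained sketch of that pattern (the toy parameter and the decay function below are placeholders, not taken from either example):

import torch
from torch.optim.lr_scheduler import LambdaLR

# any torch.optim optimizer works the same way; the base lr here is 0.1
params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = torch.optim.SGD(params, lr=0.1)

# lr_lambda returns a multiplicative factor applied to the base lr at step k
decay = lambda k: 1.0 / (1.0 + 0.01 * k)
scheduler = LambdaLR(optimizer, lr_lambda=decay)

for step in range(3):
    optimizer.step()    # parameter update with the current learning rate
    scheduler.step()    # advance the schedule: lr becomes 0.1 * decay(step + 1)
    print(scheduler.get_last_lr())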
Example #2
def main(DEVICE):
    """
    main function

    :param DEVICE: 'cpu' or 'gpu'

    """
    model = TPGST().to(DEVICE)

    print('Model {} is working...'.format(type(model).__name__))
    ckpt_dir = os.path.join(args.logdir, type(model).__name__)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = LambdaLR(optimizer, lr_policy)

    if not os.path.exists(ckpt_dir):
        os.makedirs(os.path.join(ckpt_dir, 'A', 'train'))
    else:
        print('Checkpoint directory already exists. Resuming from the latest checkpoint.')
        model_path = sorted(glob.glob(os.path.join(
            ckpt_dir, 'model-*.tar')))[-1]  # latest model
        state = torch.load(model_path)
        model.load_state_dict(state['model'])
        args.global_step = state['global_step']
        optimizer.load_state_dict(state['optimizer'])
        scheduler.last_epoch = state['scheduler']['last_epoch']
        scheduler.base_lrs = state['scheduler']['base_lrs']

    dataset = SpeechDataset(args.data_path,
                            args.meta,
                            mem_mode=args.mem_mode,
                            training=True)
    validset = SpeechDataset(args.data_path,
                             args.meta,
                             mem_mode=args.mem_mode,
                             training=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             collate_fn=collate_fn,
                             drop_last=True,
                             pin_memory=True,
                             num_workers=args.n_workers)
    valid_loader = DataLoader(dataset=validset,
                              batch_size=args.test_batch,
                              shuffle=False,
                              collate_fn=collate_fn,
                              pin_memory=True)
    # torch.set_num_threads(4)
    print('{} threads are used...'.format(torch.get_num_threads()))

    writer = SummaryWriter(ckpt_dir)
    train(model,
          data_loader,
          valid_loader,
          optimizer,
          scheduler,
          batch_size=args.batch_size,
          ckpt_dir=ckpt_dir,
          writer=writer,
          DEVICE=DEVICE)
    return None
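
The resume branch in Example #2 expects a checkpoint dictionary with 'model', 'global_step', 'optimizer', and a 'scheduler' entry holding last_epoch and base_lrs. The train() function that writes these checkpoints is not shown, so the following save helper is only a sketch consistent with those keys; its name and the zero-padded file pattern are assumptions for illustration.

import os
import torch

def save_checkpoint(model, optimizer, scheduler, global_step, ckpt_dir):
    # store exactly what the resume path in Example #2 reads back
    state = {
        'model': model.state_dict(),
        'global_step': global_step,
        'optimizer': optimizer.state_dict(),
        'scheduler': {'last_epoch': scheduler.last_epoch,
                      'base_lrs': scheduler.base_lrs},
    }
    # zero-padded step keeps sorted(glob('model-*.tar')) in chronological order
    torch.save(state, os.path.join(ckpt_dir, 'model-{:07d}.tar'.format(global_step)))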