def train(self, n_iters, lr, lamb, batch_size=None, alpha=None, test=False,
          print_interval=100, store_params=False, save_full_loss=False):
    """Train the model for a number of iterations.

    :param n_iters: integer, number of iterations to run
    :param lr: learning rate, constant or function of iteration
    :param lamb: regularization parameter, constant
    :param batch_size: constant or function of iteration, use full dataset if not given
    :param alpha: momentum schedule, function of iteration; plain forward-backward is used if not given
    :param test: if True, evaluate on the test set after every iteration
    :param print_interval: log every this many iterations
    :param store_params: if True, keep a deep copy of the parameters each iteration
    :param save_full_loss: if True, record the full-dataset loss each iteration
    """
    proximal_op = proximal.SoftThresholding(lamb)
    if alpha is None:
        # No momentum schedule: plain forward-backward (proximal gradient) steps.
        optimizer = fista.ForwardBackward(
            self.model.parameters(), 1, proximal_op,
            regularize_idxs=self.regularize_idxs)
    else:
        # Momentum schedule supplied: accelerated proximal gradient (FISTA).
        optimizer = fista.FISTA(
            self.model.parameters(), 1, proximal_op,
            regularize_idxs=self.regularize_idxs)

    # Normalize lr to a function of the iteration counter.
    if isinstance(lr, numbers.Number):
        decay = lambda _: lr
    else:
        decay = lambda k: lr(k)
    # Update the learning rate of the wrapped torch.optim optimizer;
    # resume the schedule from the stored iteration counter.
    scheduler = LambdaLR(optimizer.optimizer, lr_lambda=decay)
    scheduler.last_epoch = self.counter

    for _ in range(n_iters):
        self._train_step(optimizer, scheduler, batch_size, alpha)
        self.l1_losses.append(self.l1_loss() * lamb)
        if test:
            self.test(update=True)
        if store_params:
            self.params_his.append(deepcopy(list(self.model.parameters())))
        if self.counter % print_interval == 0:
            self.log(test)
        if save_full_loss:
            # Cross-entropy over the full dataset plus the L1 penalty.
            outputs = self.model(self.data)
            full_loss = F.cross_entropy(outputs, self.target).item()
            full_loss += self.l1_losses[-1]
            self.full_losses.append(full_loss)
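# Usage sketch (hypothetical): train() accepts either constants or functions
# of the iteration counter for `lr` and `batch_size`, and an `alpha` schedule
# to switch from plain forward-backward to FISTA. The schedules and the
# `trainer` object below are illustrative, not part of this module.

def decaying_lr(k, lr0=0.1):
    """Step size decaying as O(1/k); k is the iteration counter."""
    return lr0 / (1.0 + k)

def fista_alpha(k):
    """Classical FISTA momentum coefficient alpha_k = (k - 1) / (k + 2)."""
    return (k - 1) / (k + 2)

# trainer.train(n_iters=500, lr=decaying_lr, lamb=1e-3, alpha=fista_alpha,
#               test=True, print_interval=50)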
def main(DEVICE):
    """Main training entry point.

    :param DEVICE: 'cpu' or 'cuda'
    """
    model = TPGST().to(DEVICE)
    print('Model {} is working...'.format(type(model).__name__))
    ckpt_dir = os.path.join(args.logdir, type(model).__name__)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler = LambdaLR(optimizer, lr_policy)

    if not os.path.exists(ckpt_dir):
        os.makedirs(os.path.join(ckpt_dir, 'A', 'train'))
    else:
        print('Already exists. Retrain the model.')
        # Resume model, optimizer, and scheduler state from the latest checkpoint.
        model_path = sorted(glob.glob(os.path.join(
            ckpt_dir, 'model-*.tar')))[-1]
        state = torch.load(model_path)
        model.load_state_dict(state['model'])
        args.global_step = state['global_step']
        optimizer.load_state_dict(state['optimizer'])
        scheduler.last_epoch = state['scheduler']['last_epoch']
        scheduler.base_lrs = state['scheduler']['base_lrs']

    dataset = SpeechDataset(args.data_path, args.meta,
                            mem_mode=args.mem_mode, training=True)
    validset = SpeechDataset(args.data_path, args.meta,
                             mem_mode=args.mem_mode, training=False)
    data_loader = DataLoader(dataset=dataset, batch_size=args.batch_size,
                             shuffle=True, collate_fn=collate_fn,
                             drop_last=True, pin_memory=True,
                             num_workers=args.n_workers)
    valid_loader = DataLoader(dataset=validset, batch_size=args.test_batch,
                              shuffle=False, collate_fn=collate_fn,
                              pin_memory=True)

    # torch.set_num_threads(4)
    print('{} threads are used...'.format(torch.get_num_threads()))

    writer = SummaryWriter(ckpt_dir)
    train(model, data_loader, valid_loader, optimizer, scheduler,
          batch_size=args.batch_size, ckpt_dir=ckpt_dir, writer=writer,
          DEVICE=DEVICE)
    return None
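# NOTE: `lr_policy` is passed to LambdaLR above but is not defined in this
# snippet. The sketch below is an assumption, not the original schedule: a
# Noam-style linear warmup followed by inverse square-root decay, as is
# common for Tacotron/GST-style TTS training. `warmup_steps` is illustrative.
def lr_policy(step):
    """Multiplier applied to args.lr by LambdaLR; peaks at 1.0 at warmup_steps."""
    warmup_steps = 4000.0
    step = max(step, 1)
    return warmup_steps ** 0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)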