def test_checkpoint_eval():
    model = nn.Sequential(nn.Linear(1, 1))
    model = GPipe(model, balance=[1], devices=['cpu'], chunks=2)
    input = torch.rand(2, 1)

    def find_grad_fn(grad_fn, name):
        # Depth-first search for an autograd node with the given class name.
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        for next_grad_fn, _ in grad_fn.next_functions:
            if find_grad_fn(next_grad_fn, name):
                return True
        return False

    # Checkpointing nodes should appear in the autograd graph only in training mode.
    model.train()
    train_output = model(input)
    assert find_grad_fn(train_output.grad_fn, 'CheckpointBackward')
    assert find_grad_fn(train_output.grad_fn, 'RecomputeBackward')

    model.eval()
    eval_output = model(input)
    assert not find_grad_fn(eval_output.grad_fn, 'CheckpointBackward')
    assert not find_grad_fn(eval_output.grad_fn, 'RecomputeBackward')
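# A minimal sketch (not part of the original test suite): the same grad_fn traversal
# on a plain, un-wrapped Linear shows that 'CheckpointBackward'/'RecomputeBackward'
# nodes come only from GPipe's checkpointing, never from the module itself. The helper
# name has_grad_fn is ours and mirrors find_grad_fn above.
def example_plain_module_has_no_checkpoint_nodes():
    def has_grad_fn(grad_fn, name):
        if grad_fn is None:
            return False
        if grad_fn.__class__.__name__ == name:
            return True
        return any(has_grad_fn(next_fn, name) for next_fn, _ in grad_fn.next_functions)

    plain = nn.Linear(1, 1)
    output = plain(torch.rand(2, 1))
    assert not has_grad_fn(output.grad_fn, 'CheckpointBackward')
    assert not has_grad_fn(output.grad_fn, 'RecomputeBackward')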
class GPipeModel(object):
    """Encoder-decoder model partitioned across devices with GPipe."""

    def __init__(self, model_name, model_path, gradient_clip_value=5.0, device_ids=None, **kwargs):
        gpipe_model = nn.Sequential(gpipe_encoder(model_name, **kwargs),
                                    gpipe_decoder(model_name, **kwargs))
        # Two partitions (encoder / decoder); each mini-batch is split into 2 micro-batches.
        self.model = GPipe(gpipe_model, balance=[1, 1], chunks=2)
        self.in_device = self.model.devices[0]
        self.out_device = self.model.devices[-1]
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.model_path, self.state = model_path, {}
        os.makedirs(os.path.split(self.model_path)[0], exist_ok=True)
        self.gradient_clip_value = gradient_clip_value
        self.gradient_norm_queue = deque([np.inf], maxlen=5)
        self.optimizer = None

    def train_step(self, train_x: torch.Tensor, train_y: torch.Tensor):
        self.optimizer.zero_grad()
        self.model.train()
        scores = self.model(train_x)
        loss = self.loss_fn(scores, train_y)
        loss.backward()
        self.clip_gradient()
        self.optimizer.step(closure=None)
        return loss.item()

    def predict_step(self, data_x: torch.Tensor, k: int):
        self.model.eval()
        with torch.no_grad():
            scores, labels = torch.topk(self.model(data_x), k)
            return torch.sigmoid(scores).cpu(), labels.cpu()

    def get_optimizer(self, **kwargs):
        self.optimizer = DenseSparseAdam(self.model.parameters(), **kwargs)

    def train(self, train_loader: DataLoader, valid_loader: DataLoader, opt_params: Optional[Mapping] = None,
              nb_epoch=100, step=100, k=5, early=100, verbose=True, swa_warmup=None, **kwargs):
        self.get_optimizer(**({} if opt_params is None else opt_params))
        global_step, best_n5, e = 0, 0.0, 0
        print_loss = 0.0
        for epoch_idx in range(nb_epoch):
            if epoch_idx == swa_warmup:
                self.swa_init()
            for i, (train_x, train_y) in enumerate(train_loader, 1):
                global_step += 1
                loss = self.train_step(train_x.to(self.in_device, non_blocking=True),
                                       train_y.to(self.out_device, non_blocking=True))
                print_loss += loss
                if global_step % step == 0:
                    # Evaluate with the SWA-averaged weights swapped in.
                    self.swa_step()
                    self.swap_swa_params()
                    labels = []
                    valid_loss = 0.0
                    self.model.eval()
                    with torch.no_grad():
                        for valid_x, valid_y in valid_loader:
                            logits = self.model(valid_x.to(self.in_device, non_blocking=True))
                            valid_loss += self.loss_fn(
                                logits, valid_y.to(self.out_device, non_blocking=True)).item()
                            scores, tmp = torch.topk(logits, k)
                            labels.append(tmp.cpu())
                    valid_loss /= len(valid_loader)
                    labels = np.concatenate(labels)
                    # labels = np.concatenate([self.predict_step(valid_x, k)[1] for valid_x in valid_loader])
                    targets = valid_loader.dataset.data_y
                    p5, n5 = get_p_5(labels, targets), get_n_5(labels, targets)
                    if n5 > best_n5:
                        self.save_model(epoch_idx > 3 * swa_warmup)
                        best_n5, e = n5, 0
                    else:
                        e += 1
                        if early is not None and e > early:
                            return
                    self.swap_swa_params()
                    if verbose:
                        log_msg = '%d %d train loss: %.7f valid loss: %.7f P@5: %.5f N@5: %.5f early stop: %d' % \
                                  (epoch_idx, i * train_loader.batch_size, print_loss / step, valid_loss,
                                   round(p5, 5), round(n5, 5), e)
                        logger.info(log_msg)
                    print_loss = 0.0

    def predict(self, data_loader: DataLoader, k=100, desc='Predict', **kwargs):
        self.load_model()
        scores_list, labels_list = zip(*(self.predict_step(data_x.to(self.in_device, non_blocking=True), k)
                                         for data_x in tqdm(data_loader, desc=desc, leave=False)))
        return np.concatenate(scores_list), np.concatenate(labels_list)

    def save_model(self, last_epoch):
        if not last_epoch:
            return
        for trial in range(5):
            try:
                torch.save(self.model.state_dict(), self.model_path)
                break
            except Exception:
                print('saving failed')

    def load_model(self):
        self.model.load_state_dict(torch.load(self.model_path))

    def clip_gradient(self):
        # Adaptive clipping: the max norm tracks the largest norm seen recently.
        if self.gradient_clip_value is not None:
            max_norm = max(self.gradient_norm_queue)
            total_norm = float(torch.nn.utils.clip_grad_norm_(
                self.model.parameters(), max_norm * self.gradient_clip_value))
            self.gradient_norm_queue.append(min(total_norm, max_norm * 2.0, 1.0))
            if total_norm > max_norm * self.gradient_clip_value:
                logger.warning(F'Clipping gradients with total norm {round(total_norm, 5)} '
                               F'and max norm {round(max_norm, 5)}')

    def swa_init(self):
        # Start stochastic weight averaging: snapshot the current parameters on CPU.
        if 'swa' not in self.state:
            logger.info('SWA Initializing')
            swa_state = self.state['swa'] = {'models_num': 1}
            for n, p in self.model.named_parameters():
                swa_state[n] = p.data.cpu().detach()

    def swa_step(self):
        # Update the running average of the weights.
        if 'swa' in self.state:
            swa_state = self.state['swa']
            swa_state['models_num'] += 1
            beta = 1.0 / swa_state['models_num']
            with torch.no_grad():
                for n, p in self.model.named_parameters():
                    swa_state[n].mul_(1.0 - beta).add_(p.data.cpu(), alpha=beta)

    def swap_swa_params(self):
        # Swap the live parameters with the stored SWA averages (and back).
        if 'swa' in self.state:
            swa_state = self.state['swa']
            for n, p in self.model.named_parameters():
                gpu_id = p.get_device()
                p.data, swa_state[n] = swa_state[n], p.data.cpu()
                # p.data = p.data.cuda(gpu_id)

    def disable_swa(self):
        if 'swa' in self.state:
            del self.state['swa']
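# A hedged usage sketch for GPipeModel (illustrative only): the model name, path,
# loaders, and optimizer hyper-parameters below are placeholders, not values taken
# from this repository.
def example_gpipe_training(train_loader: DataLoader, valid_loader: DataLoader, test_loader: DataLoader):
    model = GPipeModel('AttentionXML', 'results/models/example-model', gradient_clip_value=5.0)
    model.train(train_loader, valid_loader, opt_params={'lr': 1e-3, 'weight_decay': 1e-5},
                nb_epoch=30, step=100, k=5, early=50, swa_warmup=10)
    scores, labels = model.predict(test_loader, k=100)
    return scores, labels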
def test_delete_portal_tensor(train, checkpoint):
    # Without checkpointing:
    # +- Stash --+  +--- Pop ----+ - - - layers
    # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function
    # +----------+  +------------+
    #
    # With checkpointing:
    # +- Stash --+  +--- Pop ----+  +--- Pop'----+  +- Stash'--+
    # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 |
    # +----------+  +------------+  +------------+  +----------+

    def portal_tensor_life_is(tensor_life, skip_tracker=None):
        if skip_tracker is None:
            skip_tracker = current_skip_tracker()

        # Get the current portal.
        portal = list(skip_tracker.portals.values())[0]

        if tensor_life == 0:
            return portal.tensor_life == 0 and portal.tensor is None
        else:
            return portal.tensor_life == tensor_life and portal.tensor is not None

    # Check the portal tensor after 'Stash'.
    stash_ = Stash()

    @stash_.register_forward_hook
    def check_portal_tensor_after_stash(*_):
        if is_checkpointing():
            assert portal_tensor_life_is(2)
        elif is_recomputing():
            assert portal_tensor_life_is(0)
        else:
            assert portal_tensor_life_is(1)

    pop_ = Pop()

    @pop_.register_forward_hook
    def check_portal_tensor_after_pop(*_):
        if is_checkpointing():
            assert portal_tensor_life_is(1)
        elif is_recomputing():
            assert portal_tensor_life_is(0)
        else:
            assert portal_tensor_life_is(0)

    class NoPortalTensorAtBackward(nn.Module):
        class F(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                ctx.skip_tracker = current_skip_tracker()
                return input.detach()

            @staticmethod
            def backward(ctx, grad):
                assert portal_tensor_life_is(0, skip_tracker=ctx.skip_tracker)
                return grad

        def forward(self, input):
            return self.F.apply(input)

    model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_)
    model = GPipe(model, balance=[2, 1], devices=['cpu', 'cpu'], chunks=2, checkpoint=checkpoint)

    input = torch.rand(10, requires_grad=True)

    if train:
        model.train()
        output = model(input)
        output.norm().backward()
    else:
        model.eval()
        with torch.no_grad():
            model(input)
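# For reference, the Stash and Pop modules used above are skip-connection modules
# defined elsewhere in this test module. Below is a sketch of how such modules are
# typically declared with torchgpipe's skip API; the skip name 'skip', the class
# bodies, and the mid-file import placement are assumptions for illustration, not
# code copied from this file.
from torchgpipe.skip import pop, skippable, stash


@skippable(stash=['skip'])
class ExampleStash(nn.Module):
    def forward(self, input):
        # Stash the input tensor under the name 'skip' for a later partition.
        yield stash('skip', input)
        return input


@skippable(pop=['skip'])
class ExamplePop(nn.Module):
    def forward(self, input):
        # Pop the stashed tensor and combine it with the current activation.
        skip = yield pop('skip')
        return input + skip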