def train_vae(self):
    train_state = self.train_state
    examples = self._examples
    config = self.config
    workspace = self.workspace
    vae_editor = train_state.model.vae_model
    ret_model = train_state.model.ret_model
    edit_model = train_state.model.edit_model

    train_batches = similar_size_batches(examples.train, config.optim.batch_size)
    # test batching!
    vae_editor.test_batch(train_batches[0])

    step = 0
    while step < config.optim.max_iters:
        random.shuffle(train_batches)
        for batch in verboserate(train_batches, desc='Streaming training examples'):
            loss, _, _ = vae_editor.loss(batch)
            finite_grads, grad_norm = self._take_grad_step(train_state, loss)
            self.check_gradnan(finite_grads, train_state, workspace)
            step = train_state.train_steps
            self.eval_and_save(vae_editor, step, train_state, config, grad_norm,
                               examples.train, examples.valid)
            if step >= config.optim.max_iters:
                break
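# Hedged sketch of check_gradnan, which train_vae() and train_edit() call but which is
# not shown in this section. It is assumed to mirror the inline NaN handling that the
# other training loops here write out by hand: dump a checkpoint to
# workspace.nan_checkpoints and log the offending step whenever the gradient was not
# finite. The method body below is an assumption, not the repository's implementation.
def check_gradnan(self, finite_grads, train_state, workspace):
    """Sketch: record a checkpoint when gradients were NaN/inf."""
    if finite_grads:
        return
    # dump parameters so the bad step can be inspected later
    train_state.save(workspace.nan_checkpoints)
    print 'Gradient was NaN/inf on step {}.'.format(train_state.train_steps)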
def train(self):
    """Train a model.

    NOTE: modifies TrainState in place.
    - parameters of the Editor and Optimizer are updated
    - train_steps is updated
    - random number generator states are updated at every checkpoint
    """
    # TODO(kelvin): do something to preserve random state upon reload?
    train_state = self.train_state
    examples = self._examples
    config = self.config
    workspace = self.workspace

    with random_state(self.train_state.random_state):
        editor = train_state.model
        train_batches = similar_size_batches(examples.train, config.optim.batch_size)
        editor.test_batch(train_batches[0])
        best_exact_match_score = 0.0
        while True:
            random.shuffle(train_batches)
            loss = 0
            for batch in verboserate(train_batches, desc='Streaming training examples'):
                loss, _, _ = editor.loss(batch)
                finite_grads, grad_norm = self._take_grad_step(train_state, loss)
                if not finite_grads:
                    train_state.save(workspace.nan_checkpoints)
                    examples_path = join(workspace.nan_checkpoints,
                                         '{}.examples'.format(train_state.train_steps))
                    with open(examples_path, 'w') as f:
                        pickle.dump(batch, f)
                    print 'Gradient was NaN/inf on step {}.'.format(train_state.train_steps)

                step = train_state.train_steps

                # run periodic evaluation and saving
                if step != 0:
                    if step % 10 == 0:
                        self._update_metadata(train_state)
                    if step % config.timing.eval_small == 0:
                        self.evaluate(step, big_eval=False)
                        self.tb_logger.log_value('grad_norm', grad_norm, step)
                    if step % config.timing.eval_big == 0:
                        train_stats, valid_stats = self.evaluate(step, big_eval=True)
                        exact_match_score = valid_stats[('big', 'exact_match', 'valid')]
                        self.checkpoints.save(train_state)

                if step >= config.optim.max_iters:
                    return
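# Hedged sketch of _take_grad_step, which several loops above call but which is not
# shown here. It is assumed to package the same sequence that _train() further below
# spells out inline: zero the grads, backpropagate, clip, skip the optimizer step when
# any gradient is NaN/inf, and bump the step count. The 5.0 max norm, the keyword
# argument, and the use of _finite_grads are assumptions taken from that inline code.
from torch.nn.utils import clip_grad_norm  # older torch API, matching the calls in this file

def _take_grad_step(self, train_state, loss, max_grad_norm=5.0):
    """Sketch: one optimizer step; returns (finite_grads, grad_norm)."""
    model, optimizer = train_state.model, train_state.optimizer
    optimizer.zero_grad()
    loss.backward()
    # clip_grad_norm returns the total norm computed BEFORE clipping
    grad_norm = clip_grad_norm(model.parameters(), max_grad_norm)
    finite_grads = self._finite_grads(model.parameters())
    if finite_grads:
        optimizer.step()
    train_state.increment_train_steps()
    return finite_grads, grad_norm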
def _train(cls, config, train_state, examples):
    model = train_state.model
    optimizer = train_state.optimizer
    train_batches = similar_size_batches(examples.train, config.optim.batch_size,
                                         size=lambda ex: len(ex))
    while True:
        random.shuffle(train_batches)
        i = 0  # cannot enumerate(verboserate(...))
        for batch in verboserate(train_batches, desc='Streaming training examples'):
            loss = model.loss(batch, cls._train_state.train_steps)
            cls._take_grad_step(train_state, loss)
            if (i % 5) == 0:
                cls.evaluate()
            if (i % 1000) == 0:
                if config.model.type == 1:  # SVAE
                    # write interpolations to file
                    fname = "interps_batches_{}".format(i)
                    num_ex = 10
                    a_idx = np.random.randint(len(batch), size=num_ex)
                    b_idx = np.random.randint(len(batch), size=num_ex)
                    interps = []
                    for a, b in zip(a_idx, b_idx):
                        ex_a = batch[a]
                        ex_b = batch[b]
                        interpolation = model._interpolate_examples(ex_a, ex_b)
                        interpolation_repr = []
                        interpolation_repr.append(" ".join(ex_a))
                        interpolation_repr.extend([" ".join(ex) for ex in interpolation])
                        interpolation_repr.append(" ".join(ex_b))
                        interps.append(interpolation_repr)
                    with open(join(cls._interps_dir, fname), 'w') as fout:
                        data = "\n\n".join(["\n".join(ex) for ex in interps])
                        fout.write(data.encode('utf-8'))
            if (i % 5000) == 0:
                cls.checkpoints.save(train_state)
            i += 1
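# Hedged sketch of similar_size_batches, which every training loop in this section
# relies on but whose implementation is not shown. This version simply sorts examples
# by the caller-supplied size key and chunks them, so each batch holds examples of
# roughly equal length (keeping padding waste down). The signature matches the call
# sites above; the body is an assumption.
def similar_size_batches(examples, batch_size, size=len):
    """Sketch: group examples of similar size into fixed-size batches."""
    sorted_examples = sorted(examples, key=size)
    return [sorted_examples[i:i + batch_size]
            for i in range(0, len(sorted_examples), batch_size)]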
def train(self):
    config = self.config
    train_state = self.train_state
    model, optimizer = train_state.model, train_state.optimizer

    # group into training batches
    train_batches = similar_size_batches(self.examples.train,
                                         batch_size=config.optim.batch_size,
                                         size=lambda x: len(x.output_words))

    def batch_generator():
        while True:
            # WARNING: random state of train state does not exactly restore state anymore, due to this shuffle
            random.shuffle(train_batches)
            for batch in verboserate(train_batches, desc='Streaming example batches'):
                yield batch

    with random_state(train_state.random_state):
        for batch in batch_generator():
            # take gradient step
            loss = model.loss(batch, config.optim.num_negatives)
            finite_grads = self._take_grad_step(train_state, loss)  # TODO: clip gradient?
            train_steps = train_state.train_steps
            if not finite_grads:
                print 'WARNING: grads not finite at step {}'.format(train_steps)

            self._update_metadata(train_state)

            # run periodic evaluation and saving
            if train_steps % config.eval.eval_steps == 0:
                self._evaluate(self.examples, big_eval=False)
            if train_steps % config.eval.big_eval_steps == 0:
                self._evaluate(self.examples, big_eval=True)
            if train_steps % config.eval.save_steps == 0:
                self.checkpoints.save(train_state)
            if train_steps >= config.optim.max_iters:
                return
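# Hedged sketch of the random_state context manager used by the loops above. Assumed
# behaviour: install the saved `random` module state on entry and restore the previous
# state on exit, so replaying a checkpointed TrainState reproduces the same shuffles
# (modulo the shuffle-restoration caveat in the WARNING above). The body is an
# assumption, not the repository's implementation.
from contextlib import contextmanager
import random

@contextmanager
def random_state(state):
    """Sketch: temporarily install a saved state of the global `random` RNG."""
    old_state = random.getstate()
    if state is not None:
        random.setstate(state)
    try:
        yield
    finally:
        random.setstate(old_state)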
def train_edit(self, use_lsh, topk):
    # TODO(kelvin): do something to preserve random state upon reload?
    train_state = self.train_state
    examples = self._examples
    config = self.config
    workspace = self.workspace
    vae_editor = train_state.model.vae_model
    ret_model = train_state.model.ret_model
    edit_model = train_state.model.edit_model

    # Set up static editor training
    step = train_state.train_steps
    while step < 3 * config.optim.max_iters:
        train_eval = ret_model.ret_and_make_ex(examples.train, use_lsh, examples.train, 1)
        valid_eval = ret_model.ret_and_make_ex(examples.valid, use_lsh, examples.train, 0)
        ret_batches = similar_size_batches(train_eval, config.optim.batch_size)
        random.shuffle(ret_batches)
        for batch in verboserate(ret_batches, desc='Streaming training for retrieval'):
            # Set up pairs to edit on
            fict_batch = edit_model.ident_mapper(batch, config.model.ident_pr)
            edit_loss, _, _ = edit_model.loss(fict_batch)
            loss = edit_loss
            finite_grads, grad_norm = self._take_grad_step(train_state, loss)
            self.check_gradnan(finite_grads, train_state, workspace)
            step = train_state.train_steps
            self.eval_and_save(edit_model, step, train_state, config, grad_norm,
                               train_eval, valid_eval)
            if step >= 3 * config.optim.max_iters:
                break
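# Hedged sketch of eval_and_save, called every step by train_vae() and train_edit()
# but not shown in this section. It is assumed to package the periodic-evaluation
# logic that train() above writes out inline: small evaluation plus grad-norm logging
# on one schedule, big evaluation plus checkpointing on another. The config.timing
# field names are taken from train(); the signature matches the call sites, and the
# model/example arguments are presumably consumed inside self.evaluate.
def eval_and_save(self, model, step, train_state, config, grad_norm,
                  train_examples, valid_examples):
    """Sketch: periodic evaluation, logging, and checkpointing."""
    if step == 0:
        return
    if step % config.timing.eval_small == 0:
        self.evaluate(step, big_eval=False)
        self.tb_logger.log_value('grad_norm', grad_norm, step)
    if step % config.timing.eval_big == 0:
        self.evaluate(step, big_eval=True)
        self.checkpoints.save(train_state)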
def _train(cls, config, train_state, examples, workspace, metadata, tb_logger):
    """Train a model.

    NOTE: modifies TrainState in place.
    - parameters of the Editor and Optimizer are updated
    - train_steps is updated
    - random number generator states are updated at every checkpoint

    Args:
        config (Config)
        train_state (TrainState): initial TrainState. Includes the Editor and Optimizer.
        examples (EditDataSplits)
        workspace (Workspace)
        metadata (Metadata)
        tb_logger (tensorboard_logger.Logger)
    """
    with random_state(train_state.random_state):
        editor = train_state.editor
        optimizer = train_state.optimizer
        noiser = EditNoiser(config.editor.ident_pr, config.editor.attend_pr)
        train_batches = similar_size_batches(examples.train, config.optim.batch_size)

        # test batching!
        # commenting out for now, not certain why there is a batching error.
        # editor.test_batch(noiser(train_batches[0]))

        while True:
            # TODO(kelvin): this shuffle and the position within the shuffle is not properly restored upon reload
            random.shuffle(train_batches)
            for batch in verboserate(train_batches, desc='Streaming training examples'):
                # compute gradients
                optimizer.zero_grad()
                noised_batch = noiser(batch) if config.editor.edit_dropout else batch
                var_loss, var_params, var_param_grads = editor.loss(
                    noised_batch, draw_samples=config.editor.enable_vae)

                # always clip gradients
                # TODO: make the max norm tunable, not hard-coded
                grad_norm = clip_grad_norm(editor.parameters(), 5.)
                finite_grads = cls._finite_grads(editor.parameters())

                # take a step only if the grads are finite
                if finite_grads:
                    optimizer.step()

                # increment step count
                train_state.increment_train_steps()

                # somehow we encountered NaN
                if not finite_grads:
                    # dump parameters
                    train_state.save(workspace.nan_checkpoints)

                    # dump offending example batch
                    examples_path = join(workspace.nan_checkpoints,
                                         '{}.examples'.format(train_state.train_steps))
                    with open(examples_path, 'w') as f:
                        pickle.dump(noised_batch, f)

                    print 'Gradient was NaN/inf on step {}.'.format(train_state.train_steps)

                    # if there were more than 5 NaNs in the last 10 steps, drop into the debugger
                    nan_steps = cls._checkpoint_numbers(workspace.nan_checkpoints)
                    recent_nans = [s for s in nan_steps if s > train_state.train_steps - 10]
                    if len(recent_nans) > 5:
                        print 'Too many NaNs encountered recently: {}. Entering debugger.'.format(recent_nans)
                        import pdb
                        pdb.set_trace()

                # run periodic evaluation and saving
                if train_state.train_steps % config.eval.eval_steps == 0:
                    cls._evaluate(config, editor, examples, metadata, tb_logger,
                                  train_state.train_steps, noiser, big_eval=False)
                    tb_logger.log_value('grad_norm', grad_norm, train_state.train_steps)
                if train_state.train_steps % config.eval.big_eval_steps == 0:
                    cls._evaluate(config, editor, examples, metadata, tb_logger,
                                  train_state.train_steps, noiser, big_eval=True)
                if train_state.train_steps % config.eval.save_steps == 0:
                    train_state.update_random_state()
                    train_state.save(workspace.checkpoints)
                if train_state.train_steps >= config.optim.max_iters:
                    return
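# Hedged sketch of cls._finite_grads, used above to decide whether to apply the
# optimizer step. Assumed to simply check every parameter gradient for NaN/inf values;
# the use of numpy and the exact traversal are assumptions.
import numpy as np

def _finite_grads(parameters):
    """Sketch: True iff every gradient value is finite."""
    for param in parameters:
        if param.grad is None:
            continue
        if not np.isfinite(param.grad.data.cpu().numpy()).all():
            return False
    return True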