Example #1
0
    def train_vae(self):
        """Train only the VAE component of the model."""
        train_state = self.train_state
        examples = self._examples
        config = self.config
        workspace = self.workspace

        vae_editor = train_state.model.vae_model
        ret_model = train_state.model.ret_model
        edit_model = train_state.model.edit_model
        train_batches = similar_size_batches(examples.train,
                                             config.optim.batch_size)

        # test batching!
        vae_editor.test_batch(train_batches[0])

        step = 0
        while step < config.optim.max_iters:
            random.shuffle(train_batches)
            for batch in verboserate(train_batches,
                                     desc='Streaming training examples'):
                loss, _, _ = vae_editor.loss(batch)
                finite_grads, grad_norm = self._take_grad_step(
                    train_state, loss)
                self.check_gradnan(finite_grads, train_state, workspace)
                step = train_state.train_steps
                self.eval_and_save(vae_editor, step, train_state, config,
                                   grad_norm, examples.train, examples.valid)
                if step >= config.optim.max_iters:
                    break
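This loop, and the ones that follow, rely on two helpers that are not shown here: similar_size_batches, which groups training examples into batches of roughly equal size, and verboserate, which wraps an iterable with progress reporting. A minimal sketch of plausible implementations follows; the bodies are assumptions inferred from the call sites, not the project's actual code.

from tqdm import tqdm


def similar_size_batches(examples, batch_size, size=len):
    """Group examples into batches of similar size (assumed behavior).

    Sorting by size before slicing keeps each batch roughly homogeneous,
    which reduces padding when batches are later tensorized.
    """
    ordered = sorted(examples, key=size)
    return [ordered[i:i + batch_size]
            for i in range(0, len(ordered), batch_size)]


def verboserate(iterable, desc=''):
    """Iterate with a progress bar (assumed to be a thin wrapper over tqdm)."""
    return tqdm(iterable, desc=desc)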
Example #2
0
    def train(self):
        """Train a model.

        NOTE: modifies TrainState in place.
        - parameters of the Editor and Optimizer are updated
        - train_steps is updated
        - random number generator states are updated at every checkpoint
        """
        # TODO(kelvin): do something to preserve random state upon reload?
        train_state = self.train_state
        examples = self._examples
        config = self.config
        workspace = self.workspace
        with random_state(self.train_state.random_state):
            editor = train_state.model
            train_batches = similar_size_batches(examples.train,
                                                 config.optim.batch_size)
            editor.test_batch(train_batches[0])
            best_exact_match_score = 0.0
            while True:
                random.shuffle(train_batches)
                loss = 0
                for batch in verboserate(train_batches,
                                         desc='Streaming training examples'):
                    loss, _, _ = editor.loss(batch)

                    finite_grads, grad_norm = self._take_grad_step(
                        train_state, loss)
                    if not finite_grads:
                        train_state.save(workspace.nan_checkpoints)

                        examples_path = join(
                            workspace.nan_checkpoints,
                            '{}.examples'.format(train_state.train_steps))
                        with open(examples_path, 'wb') as f:
                            pickle.dump(batch, f)

                        print('Gradient was NaN/inf on step {}.'.format(
                            train_state.train_steps))

                    step = train_state.train_steps

                    # run periodic evaluation and saving
                    if step != 0:
                        if step % 10 == 0:
                            self._update_metadata(train_state)
                        if step % config.timing.eval_small == 0:
                            self.evaluate(step, big_eval=False)
                            self.tb_logger.log_value('grad_norm', grad_norm,
                                                     step)
                        if step % config.timing.eval_big == 0:
                            train_stats, valid_stats = self.evaluate(
                                step, big_eval=True)
                            exact_match_score = valid_stats[('big',
                                                             'exact_match',
                                                             'valid')]
                            # track the best validation exact match seen so far
                            best_exact_match_score = max(best_exact_match_score,
                                                         exact_match_score)
                            self.checkpoints.save(train_state)
                        if step >= config.optim.max_iters:
                            return
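Example #2 wraps training in a random_state context manager so that resuming from a checkpoint reproduces the RNG stream. A minimal sketch, assuming it only manages Python's built-in random module (the real version may also cover NumPy and PyTorch states):

import random
from contextlib import contextmanager


@contextmanager
def random_state(state):
    """Run a block under a saved random.getstate() value (sketch).

    The caller's previous RNG state is restored when the block exits.
    """
    previous = random.getstate()
    if state is not None:
        random.setstate(state)
    try:
        yield
    finally:
        random.setstate(previous)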
Example #3
0
    def _train(cls, config, train_state, examples):
        """Train the model; periodically evaluate, checkpoint, and (for the SVAE) dump latent interpolations."""
        model = train_state.model
        optimizer = train_state.optimizer
        train_batches = similar_size_batches(examples.train,
                                             config.optim.batch_size,
                                             size=lambda ex: len(ex))

        while True:
            random.shuffle(train_batches)
            i = 0  # cannot enumerate(verboserate(...))
            for batch in verboserate(train_batches,
                                     desc='Streaming training examples'):
                loss = model.loss(batch, train_state.train_steps)
                cls._take_grad_step(train_state, loss)
                if (i % 5) == 0:
                    cls.evaluate()
                if (i % 1000) == 0:
                    if config.model.type == 1:  # SVAE
                        # write interpolations to file
                        fname = "interps_batches_{}".format(i)
                        num_ex = 10
                        a_idx = np.random.randint(len(batch), size=num_ex)
                        b_idx = np.random.randint(len(batch), size=num_ex)
                        interps = []
                        for a, b in zip(a_idx, b_idx):
                            ex_a = batch[a]
                            ex_b = batch[b]
                            interpolation = model._interpolate_examples(
                                ex_a, ex_b)
                            interpolation_repr = []
                            interpolation_repr.append(" ".join(ex_a))
                            interpolation_repr.extend(
                                [" ".join(ex) for ex in interpolation])
                            interpolation_repr.append(" ".join(ex_b))
                            interps.append(interpolation_repr)
                        with open(join(cls._interps_dir, fname), 'wb') as fout:
                            data = "\n\n".join(
                                ["\n".join(ex) for ex in interps])
                            fout.write(data.encode('utf-8'))
                if (i % 5000) == 0:
                    cls.checkpoints.save(train_state)
                i += 1
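The SVAE branch above writes decoded interpolations between random pairs of in-batch examples. model._interpolate_examples is not shown; the sketch below illustrates the usual latent-space recipe it presumably follows, using hypothetical encode/decode methods that stand in for whatever the real model exposes.

import numpy as np


def interpolate_examples(model, ex_a, ex_b, num_steps=8):
    """Decode points along the line between two examples' latent codes (sketch).

    model.encode / model.decode are hypothetical names; the standard recipe is
    z(t) = (1 - t) * z_a + t * z_b for t strictly between 0 and 1.
    """
    z_a = model.encode(ex_a)
    z_b = model.encode(ex_b)
    decoded = []
    for t in np.linspace(0.0, 1.0, num_steps)[1:-1]:  # skip the endpoints
        z = (1.0 - t) * z_a + t * z_b
        decoded.append(model.decode(z))
    return decoded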
    def train(self):
        """Train the model, with periodic evaluation and checkpointing."""
        config = self.config
        train_state = self.train_state
        model, optimizer = train_state.model, train_state.optimizer

        # group into training batches
        train_batches = similar_size_batches(self.examples.train, batch_size=config.optim.batch_size,
                                             size=lambda x: len(x.output_words))

        def batch_generator():
            while True:
                # WARNING: because of this shuffle, the saved random state no longer exactly restores training on reload
                random.shuffle(train_batches)
                for batch in verboserate(train_batches, desc='Streaming example batches'):
                    yield batch

        with random_state(train_state.random_state):
            for batch in batch_generator():
                # take gradient step
                loss = model.loss(batch, config.optim.num_negatives)
                finite_grads = self._take_grad_step(train_state, loss)  # TODO: clip gradient?
                train_steps = train_state.train_steps

                if not finite_grads:
                    print('WARNING: grads not finite at step {}'.format(train_steps))

                self._update_metadata(train_state)

                # run periodic evaluation and saving
                if train_steps % config.eval.eval_steps == 0:
                    self._evaluate(self.examples, big_eval=False)

                if train_steps % config.eval.big_eval_steps == 0:
                    self._evaluate(self.examples, big_eval=True)

                if train_steps % config.eval.save_steps == 0:
                    self.checkpoints.save(train_state)

                if train_steps >= config.optim.max_iters:
                    return
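Most of these loops delegate the actual parameter update to a _take_grad_step helper that returns a finite-gradient flag (and, in some versions, the gradient norm as well). A minimal sketch of the two-value variant, assuming a PyTorch model and optimizer stored on the TrainState; the clipping threshold and exact bookkeeping are assumptions.

import math

from torch.nn.utils import clip_grad_norm_  # spelled clip_grad_norm in older PyTorch


def take_grad_step(train_state, loss, max_grad_norm=5.0):
    """Backprop loss, clip gradients, and apply one optimizer step (sketch).

    The optimizer step is skipped when the gradient norm is NaN/inf, but the
    step counter still advances, matching how the loops above read
    train_state.train_steps right after the call.
    """
    model = train_state.model          # assumes model/optimizer live on TrainState
    optimizer = train_state.optimizer

    optimizer.zero_grad()
    loss.backward()

    # returns the total gradient norm computed *before* clipping
    grad_norm = float(clip_grad_norm_(model.parameters(), max_grad_norm))
    finite_grads = not (math.isnan(grad_norm) or math.isinf(grad_norm))

    if finite_grads:
        optimizer.step()
    train_state.increment_train_steps()
    return finite_grads, grad_norm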
Example #5
0
    def train_edit(self, use_lsh, topk):
        """Train the edit model on examples paired with retrieved prototypes."""
        # TODO(kelvin): do something to preserve random state upon reload?
        train_state = self.train_state
        examples = self._examples
        config = self.config
        workspace = self.workspace

        vae_editor = train_state.model.vae_model
        ret_model = train_state.model.ret_model
        edit_model = train_state.model.edit_model

        # Set up static editor training
        step = train_state.train_steps
        while step < 3 * config.optim.max_iters:
            train_eval = ret_model.ret_and_make_ex(examples.train, use_lsh,
                                                   examples.train, 1)
            valid_eval = ret_model.ret_and_make_ex(examples.valid, use_lsh,
                                                   examples.train, 0)
            ret_batches = similar_size_batches(train_eval,
                                               config.optim.batch_size)
            random.shuffle(ret_batches)
            for batch in verboserate(ret_batches,
                                     desc='Streaming training for retrieval'):
                # Set up pairs to edit on
                fict_batch = edit_model.ident_mapper(batch,
                                                     config.model.ident_pr)
                edit_loss, _, _ = edit_model.loss(fict_batch)
                loss = edit_loss
                finite_grads, grad_norm = self._take_grad_step(
                    train_state, loss)
                self.check_gradnan(finite_grads, train_state, workspace)
                step = train_state.train_steps
                self.eval_and_save(edit_model, step, train_state, config,
                                   grad_norm, train_eval, valid_eval)

                if step >= 3 * config.optim.max_iters:
                    break
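train_vae and train_edit factor the NaN handling into a check_gradnan helper whose body is not shown. Judging from the inline handling in Example #2, it presumably looks something like the sketch below (written as a plain function here; in the examples it is a method on the trainer).

def check_gradnan(finite_grads, train_state, workspace):
    """Dump a checkpoint and warn when gradients were NaN/inf (sketch).

    The real method may also pickle the offending batch or drop into a
    debugger after repeated failures, as the inlined versions above do.
    """
    if finite_grads:
        return
    train_state.save(workspace.nan_checkpoints)
    print('Gradient was NaN/inf on step {}.'.format(train_state.train_steps))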
    def _train(cls, config, train_state, examples, workspace, metadata, tb_logger):
        """Train a model.

        NOTE: modifies TrainState in place.
        - parameters of the Editor and Optimizer are updated
        - train_steps is updated
        - random number generator states are updated at every checkpoint

        Args:
            config (Config)
            train_state (TrainState): initial TrainState. Includes the Editor and Optimizer.
            examples (EditDataSplits)
            workspace (Workspace)
            metadata (Metadata)
            tb_logger (tensorboard_logger.Logger)
        """
        with random_state(train_state.random_state):
            editor = train_state.editor
            optimizer = train_state.optimizer
            noiser = EditNoiser(config.editor.ident_pr, config.editor.attend_pr)
            train_batches = similar_size_batches(examples.train, config.optim.batch_size)

            # test batching!
            # commented out for now; not certain why there is a batching error
            # editor.test_batch(noiser(train_batches[0]))

            while True:
                # TODO(kelvin): this shuffle and the position within the shuffle is not properly restored upon reload
                random.shuffle(train_batches)

                for batch in verboserate(train_batches, desc='Streaming training examples'):
                    # compute gradients
                    optimizer.zero_grad()
                    if config.editor.edit_dropout:
                        noised_batch = noiser(batch)
                    else:
                        noised_batch = batch
                    var_loss, var_params, var_param_grads = editor.loss(noised_batch, draw_samples=config.editor.enable_vae)
                    # backpropagate before clipping and stepping the optimizer
                    var_loss.backward()

                    """
                    # clip gradients
                    if train_state.train_steps < 50:
                        # don't clip, just observe the gradient norm
                        grad_norm = clip_grad_norm(editor.parameters(), float('inf'), norm_type=2)
                        train_state.track_grad_norms(grad_norm)
                        metadata['max_grad_norm'] = train_state.max_grad_norm
                    else:
                        # clip according to the max allowed grad norm
                        grad_norm = clip_grad_norm(editor.parameters(), train_state.max_grad_norm)
                        # this returns the gradient norm BEFORE clipping
                    """

                    # Always do gradient clipping
                    # TODO: make the max norm tunable, not hard-coded
                    grad_norm = clip_grad_norm(editor.parameters(), 5.)  # clip_grad_norm_ in newer PyTorch

                    finite_grads = cls._finite_grads(editor.parameters())

                    # take a step if the grads are finite
                    if finite_grads:
                        optimizer.step()

                    # increment step count
                    train_state.increment_train_steps()

                    # somehow we encountered NaN
                    if not finite_grads:
                        # dump parameters
                        train_state.save(workspace.nan_checkpoints)

                        # dump offending example batch
                        examples_path = join(workspace.nan_checkpoints, '{}.examples'.format(train_state.train_steps))
                        with open(examples_path, 'wb') as f:
                            pickle.dump(noised_batch, f)

                        print('Gradient was NaN/inf on step {}.'.format(train_state.train_steps))

                        # if there were more than 5 NaNs in the last 10 steps, drop into the debugger
                        nan_steps = cls._checkpoint_numbers(workspace.nan_checkpoints)
                        recent_nans = [s for s in nan_steps if s > train_state.train_steps - 10]
                        if len(recent_nans) > 5:
                            print('Too many NaNs encountered recently: {}. Entering debugger.'.format(recent_nans))
                            import pdb
                            pdb.set_trace()

                    # run periodic evaluation and saving
                    if train_state.train_steps % config.eval.eval_steps == 0:
                        cls._evaluate(config, editor, examples, metadata, tb_logger, train_state.train_steps, noiser, big_eval=False)
                        tb_logger.log_value('grad_norm', grad_norm, train_state.train_steps)

                    if train_state.train_steps % config.eval.big_eval_steps == 0:
                        cls._evaluate(config, editor, examples, metadata, tb_logger, train_state.train_steps, noiser, big_eval=True)

                    if train_state.train_steps % config.eval.save_steps == 0:
                        train_state.update_random_state()
                        train_state.save(workspace.checkpoints)

                    if train_state.train_steps >= config.optim.max_iters:
                        return
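Finally, the cls._finite_grads(editor.parameters()) check used above is not shown either. A small sketch of such a check, assuming PyTorch parameters whose .grad can be None for unused weights:

import torch


def finite_grads(parameters):
    """Return True only if every existing gradient contains finite values."""
    for param in parameters:
        if param.grad is None:
            continue
        if not bool(torch.isfinite(param.grad).all()):
            return False
    return True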