def main():
    """Meant to be invoked for this runner"""
    folderpath = os.path.join('out', 'mytyping', 'runners',
                              'transfer_train_ussp', '1layer', '0', '1')
    words = mwords.load_custom(os.path.join(folderpath, 'words.txt'))
    ssp = ussp.UniformSSP(words.words, 64)

    network = torch.load(
        os.path.join(folderpath, 'trained_models', 'epoch_finished.pt'))

    teacher = ss1.EncoderDecoderTeacher(menc.stop_failer, 30)
    _eval(ssp, teacher, network)
    tracker = mtnr.AccuracyTracker(1, len(words.words), False)

    _logger = logging.getLogger(__name__)
    _logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)s %(message)s',
                        datefmt='%m/%d/%Y %I:%M:%S %p')
    ctx = stnr.SSPGenericTrainingContext(model=network,
                                         teacher=teacher,
                                         train_ssp=ssp,
                                         test_ssp=ssp,
                                         optimizers=[],
                                         batch_size=1,
                                         shared={},
                                         logger=_logger,
                                         perf_stats=NoopPerfStats())
    ctx.shared['epochs'] = tnr.EpochsTracker()
    ctx.shared['epochs'].new_epoch = True

    tracker.setup(ctx)
    tracker.pre_loop(ctx)
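
If this file is executed directly, the usual guard would call the runner; a minimal sketch:

if __name__ == '__main__':
    main()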
Example 2
    def teach_many(self, network: EncoderDecoder, optimizers: typing.List[torch.optim.Optimizer],
                   criterion: typing.Any, inputs: typing.List[Sequence],
                   outputs: typing.List[Sequence],
                   perf_stats: PerfStats = NoopPerfStats()) -> float:
        """Teaches the network on each input/output pair and returns the mean loss."""
        all_losses = torch.zeros(len(inputs), dtype=torch.double)
        for ind, inp in enumerate(inputs):
            out = outputs[ind]
            all_losses[ind] = self.teach_single(network, optimizers, criterion,
                                                inp, out, perf_stats)
        return all_losses.mean().item()
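
A sketch of driving teach_many by hand, mirroring how the training loop in Example 9 draws batches; `teacher`, `network`, `optimizers`, `criterion`, and `ssp` are assumed to be set up as in Example 1, and the batch size of 16 is made up:

inputs, outputs = [], []
for _ in range(16):  # hypothetical batch size
    inp, out = next(ssp)
    inputs.append(inp)
    outputs.append(out)
mean_loss = teacher.teach_many(network, optimizers, criterion, inputs, outputs)
print('mean loss over batch: %.4f' % mean_loss)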
Example 3
    def classify_many(self, network: EncoderDecoder,
                      inputs: typing.List[Sequence],
                      perf_stats: PerfStats = NoopPerfStats()) -> typing.List[Sequence]:
        """Runs each input through the network and collects each decoded output sequence."""
        result = []
        with torch.set_grad_enabled(False):
            for inp in inputs:
                if torch.is_tensor(inp.raw):
                    inp_tensor = inp.raw
                else:
                    inp_tensor = torch.tensor(inp.raw, dtype=torch.double)  # pylint: disable=not-callable
                reader = EDReader(self.stop_failer, self.max_out_len)
                network(inp_tensor, reader, perf_stats)
                result.append(Sequence(raw=reader.output))
        return result
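
Usage might be sketched as follows; each returned Sequence wraps the outputs the EDReader collected:

predictions = teacher.classify_many(network, inputs)
for inp, pred in zip(inputs, predictions):
    print('%d input steps -> %d output steps' % (len(inp.raw), len(pred.raw)))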
Example 4

    def get_current(
        self, perf_stats: PerfStats = NoopPerfStats()
    ) -> typing.Tuple[Sequence, Sequence]:
        """Encodes the word at the current position into an (input, output) pair of Sequences."""
        if PRE_ENCODE:
            return self.encoded_words[self.position]

        word = self.words[self.position]
        perf_stats.enter('ENCODE_INPUT')
        inputs = [menc.encode_input(i) for i in word]
        inputs.append(menc.encode_input_stop())
        perf_stats.exit_then_enter('ENCODE_OUTPUT')
        outputs = [menc.encode_output(i, self.char_delay) for i in word]
        outputs.append(menc.encode_output_stop())
        perf_stats.exit()

        return Sequence(raw=inputs), Sequence(raw=outputs)
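
Consuming this might look like the sketch below, assuming `ssp` is an instance of this problem class positioned at some word; each raw list holds one encoded character per step plus the final stop marker:

inp_seq, out_seq = ssp.get_current()
print('%d input steps, %d output steps' % (len(inp_seq.raw), len(out_seq.raw)))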
Example 5
    def forward(self, inp: torch.Tensor, callback: typing.Callable = None, # pylint: disable=arguments-differ
                perf_stats: PerfStats = NoopPerfStats()
                ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Goes from the input vector to a sequence. The sequence is returned by invoking the
        callback at each step; the callback must return True if the network should continue or
        False if it is finished. This operation cannot be batched.

        If you need to do teacher forcing, you can have the callback return True and then
        reinvoke the network.

        Arguments:
            inp (tensor [input_dim]): the context vector to forward through the network
            callback (function(output (tensor [output_dim])) -> bool, optional): accepts each step
                of the network and returns True to continue generating the sequence and False to
                stop. If None, the network takes a single step and returns without computing an
                output.
            perf_stats (PerfStats, optional): used to handle performance tracking. Assumed to
                have been started and to have already entered a region that identifies this as
                the decoder rnn.
        Returns:
            (hidden, state): the final hidden output and recurrent state of the network
        """
        tus.check_tensors(inp=(inp, [('input_dim', self.input_dim)], torch.double))
        perf_stats.enter('CONTEXT_TO_HIDDEN')
        inp_interp = self.context_to_hidden(inp)

        perf_stats.exit_then_enter('HIDDEN_THROUGH_GRU')
        hidden, state = self.hidden_through_gru(inp_interp, None)
        perf_stats.exit()

        if callback is None:
            return hidden, state
        else:
            perf_stats.enter('HIDDEN_TO_OUTPUT')
            out = self.hidden_to_output(hidden, state)
            perf_stats.exit()
            if not callback(out):
                return hidden, state

        # Keep stepping the GRU from its own hidden state until the callback signals stop
        while True:
            perf_stats.enter('HIDDEN_THROUGH_GRU')
            hidden, state = self.hidden_through_gru(hidden, state)

            perf_stats.exit_then_enter('HIDDEN_TO_OUTPUT')
            out = self.hidden_to_output(hidden, state)
            perf_stats.exit()
            if not callback(out):
                return hidden, state
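
The EDReader in Example 3 plays the callback role in practice; a minimal hand-rolled equivalent that collects outputs up to a hypothetical cap might look like:

collected = []

def on_step(out):
    collected.append(out.detach())
    return len(collected) < 30  # hypothetical maximum output length

decoder(context_vector, on_step)  # `decoder` and `context_vector` assumed from context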
Example 6
    def forward(self, inp: torch.Tensor, # pylint: disable=arguments-differ
                perf_stats: PerfStats = NoopPerfStats()) -> torch.Tensor:
        """Presents the given string to the network and returns the context state

        Arguments:
            inp (torch.Tensor [batch size, sequence length, input size]): tensor containing the
                input features. Multiple sequences of the same length may be sent at once; the
                first index selects the batch element, the second index the timestep within the
                sequence, and the last index the feature.
            perf_stats (PerfStats): used for performance tracking, assumed to have already been
                started and to be within a region that identifies this object. Use NoopPerfStats
                to disable tracking.
        Returns:
            context_vector (torch.Tensor [batch_size, output_size])
        """
        tus.check_tensors(inp=(inp, (('batch_size', None), ('sequence length', None),
                                     ('input size', self.input_dim)), torch.double))

        perf_stats.enter('IN_INTERP')
        interpinp = self.in_interpreter(inp.reshape(inp.shape[0] * inp.shape[1], inp.shape[2]))
        perf_stats.exit_then_enter('IN_NONLIN')
        interpinp = self.in_nonlinearity(
            interpinp
        ).reshape(inp.shape[0], inp.shape[1], self.hidden_size)
        perf_stats.exit_then_enter('GRU')
        out1, out2 = self.gru(interpinp)
        perf_stats.exit()
        out2 = out2.transpose(0, 1) # nn.GRU returns the state as (num layers, batch, hidden); put batch first
        tus.check_tensors(out1=(out1, (('batch_size', inp.shape[0]), ('sequence length', inp.shape[1]),
                                       ('hidden_size', self.hidden_size)), torch.double),
                          out2=(out2, (('batch_size', inp.shape[0]), ('num layers', self.num_layers),
                                       ('hidden_size', self.hidden_size)), torch.double))
        out1_last = out1[:, -1, :].reshape(out1.shape[0], self.hidden_size)
        out2_reshaped = out2.reshape(out2.shape[0], self.num_layers * self.hidden_size)
        perf_stats.enter('CAT')
        stacked = torch.cat((out1_last, out2_reshaped), dim=1)
        perf_stats.exit_then_enter('OUT_INTERP')
        out = self.out_interpreter(stacked)
        perf_stats.exit_then_enter('OUT_NONLIN')
        out = self.out_nonlinearity(out)
        perf_stats.exit()
        tus.check_tensors(out=(out, (('batch_size', inp.shape[0]),
                                     ('output_size', self.output_dim)), torch.double))
        return out
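
A shape-level sketch of calling the encoder with made-up dimensions (a batch of 2 sequences of length 5, and an input_dim of 10):

inp = torch.zeros((2, 5, 10), dtype=torch.double)  # [batch, seq len, input size]
context = encoder(inp)  # `encoder` assumed constructed with input_dim=10
print(context.shape)    # torch.Size([2, output_size])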
Example 7
    def forward(self, inp: torch.Tensor, callback: typing.Callable, # pylint: disable=arguments-differ
                perf_stats: PerfStats = NoopPerfStats()) -> None:
        """Forwards through the network. On its own this is not sufficient for training, which
        usually needs to distinguish the encoding and decoding phases, but it makes the forward
        pass straightforward.

        Arguments:
            inp (torch.Tensor [input_sequence_length, input_dim]): the sequence to present to the
                network.
            callback (function(tensor [output_dim]) -> bool): presented with each output element
                of the sequence until it returns False
            perf_stats (PerfStats, optional): used to store performance information
        """

        perf_stats.enter('ENCODE')
        context = self.encoder(inp.unsqueeze(dim=0), perf_stats)
        perf_stats.exit_then_enter('DECODE')
        self.decoder(context.squeeze(), callback, perf_stats)
        perf_stats.exit()
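
Putting the halves together, a full pass over one sequence might be sketched as below; the sizes and the cap of 30 steps (matching the stop length in Example 1) are assumptions:

inp = torch.zeros((7, 10), dtype=torch.double)  # [seq len, input_dim], made-up sizes
steps = []

def on_step(out):
    steps.append(out)
    return len(steps) < 30  # hypothetical cap

network(inp, on_step)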
Example 8
    def get_current(
        self, perf_stats: PerfStats = NoopPerfStats()
    ) -> typing.Tuple[Sequence, Sequence]:
        """Gets the value at the current position. Returns the input and the output in that order"""
        raise NotImplementedError()
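
A toy concrete implementation, purely illustrative (the base class name is assumed), that always serves one fixed pair:

class SequenceSequenceProblem:  # hypothetical base class exposing get_current
    pass

class ConstantProblem(SequenceSequenceProblem):
    def __init__(self, inp: Sequence, out: Sequence):
        self.inp = inp
        self.out = out

    def get_current(
        self, perf_stats: PerfStats = NoopPerfStats()
    ) -> typing.Tuple[Sequence, Sequence]:
        return self.inp, self.out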
Example 9
    def train(self, model: Network, **kwargs):
        """Trains the given sequence to sequence model"""

        _logger = kwargs.pop('logger', None)
        if _logger is None:
            _logger = logging.getLogger(__name__)
            _logger.setLevel(logging.DEBUG)
            logging.basicConfig(format='%(asctime)s %(message)s',
                                datefmt='%m/%d/%Y %I:%M:%S %p')

        perf_stats = kwargs.pop('perf_stats', NoopPerfStats())

        context = SSPGenericTrainingContext(model=model,
                                            teacher=self.teacher,
                                            train_ssp=self.train_ssp,
                                            test_ssp=self.test_ssp,
                                            batch_size=self.batch_size,
                                            optimizers=self.optimizers,
                                            shared=dict(),
                                            logger=_logger,
                                            perf_stats=perf_stats)
        del _logger
        del perf_stats

        if self.learning_rate is not None:
            for optim in context.optimizers:
                for param_group in optim.param_groups:
                    param_group['lr'] = self.learning_rate

        self.setup(context, **kwargs)

        # Main loop: draw a batch from the training problem, teach, then run the decay and stop hooks
        while True:
            context.perf_stats.enter('PRE_LOOP')
            self.pre_loop(context)
            context.perf_stats.exit_then_enter('GET_INPUTS')

            inputs = []
            outputs = []
            for _ in range(context.batch_size):
                inp, out = next(context.train_ssp)
                inputs.append(inp)
                outputs.append(out)

            context.perf_stats.exit_then_enter('PRE_TRAIN')
            self.pre_train(context)
            context.perf_stats.exit_then_enter('TEACH_MANY')
            loss = context.teacher.teach_many(context.model,
                                              context.optimizers,
                                              self.criterion, inputs, outputs,
                                              context.perf_stats)
            context.perf_stats.exit_then_enter('POST_TRAIN')
            self.post_train(context, loss)
            context.perf_stats.exit_then_enter('DECAY_SCHEDULER')
            if self.decay_scheduler(context, loss, False):
                context.perf_stats.exit_then_enter('DECAY')
                context = self.decay(context)
            context.perf_stats.exit_then_enter('STOPPER')
            if self.stopper(context):
                context.perf_stats.exit()
                break
            context.perf_stats.exit()

        result = dict()
        context.perf_stats.enter('FINISHED')
        self.finished(context, result)
        context.perf_stats.exit()
        return result
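
Calling train might be sketched as follows, where `trainer` is an instance of a subclass of this class wired with its teacher, problems, optimizers, and hooks, and `model` is the Network to train; both kwargs are optional, as the body above shows:

result = trainer.train(model)  # falls back to a default logger and NoopPerfStats
result = trainer.train(model, logger=my_logger, perf_stats=my_stats)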