Example #1
    def test_dataloader(self):
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Creating dataloader train")

        data_set = pd.read_pickle(
            f"{self.global_state.data_dir}/allData.pickle")

        data_set = data_set.fillna(0.0).values[:, 1:].astype(np.float32)
        data_set = torch.tensor(data_set)

        # Note that batches have size 1!
        dataloader = DataLoader(IterableTimeSeries(self.global_state,
                                                   data_set,
                                                   mode="test"),
                                batch_size=1,
                                num_workers=self.global_state.num_workers,
                                drop_last=True)
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Dataloader length: {len(dataloader)}")

        return dataloader
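
Since the batches here have size 1, every item gains a leading singleton batch dimension. A self-contained toy sketch of that shape effect, with a plain TensorDataset standing in for IterableTimeSeries and assumed sizes:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for IterableTimeSeries: 10 (x, y) pairs, n_model = 60, d_model = 4.
xs = torch.randn(10, 60, 4)
ys = torch.randn(10, 4)
loader = DataLoader(TensorDataset(xs, ys), batch_size=1, drop_last=True)

x, y = next(iter(loader))
assert x.size() == (1, 60, 4) and y.size() == (1, 4)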
Example #2
    def __len__(self):
        """
        Total number of samples in the dataset
        """
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(
                f"{self.__class__.__name__}, "
                f"{inspect.currentframe().f_code.co_name}: "
                f" Call to __len__() on {self.data_type} returning"
                f" self.data.size()[0] - (self.global_state.n_model + 1) ="
                f" {self.data.size()[0] - (self.global_state.n_model + 1)}")
        return self.data.size()[0] - (self.global_state.n_model + 1)
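
The length above is simply the row count minus (n_model + 1): each sample consumes n_model input rows plus one target row. A toy check of that arithmetic with assumed sizes (not the author's data):

n_rows, n_model = 100, 60           # assumed toy values
n_samples = n_rows - (n_model + 1)  # 39, as __len__ above would return
# The last item, index 38, reads rows 38..97 as x and row 98 as y,
# which still fits inside the 100 rows.
last_index = n_samples - 1
assert last_index + n_model < n_rows
assert n_samples == 39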
Example #3
    def build_optimizer(self, reload=False):
        if self.global_state.optim.lower() == "sgd":
            optimizer = optim.SGD(
                self.transformer_model.parameters(),
                lr=self.global_state.lr,
                momentum=self.global_state.mom,
            )
        elif self.global_state.optim.lower() == "adam":
            optimizer = optim.Adam(params=self.transformer_model.parameters(),
                                   lr=self.global_state.lr)

        elif self.global_state.optim.lower() == "adagrad":
            optimizer = optim.Adagrad(self.transformer_model.parameters(),
                                      lr=self.global_state.lr)
        else:
            raise ValueError(
                f"optimizer type {self.global_state.optim} not recognized")

        if reload:
            if self.global_state.restart_from is not None:
                optim_name = f"optimizer_{self.global_state.restart_from}.pt"
            else:
                optim_name = "optimizer.pt"

            optim_file_name = os.path.join(self.global_state.restart_dir,
                                           optim_name)
            logging(f"reloading {optim_file_name}")
            if os.path.exists(optim_file_name):
                with open(optim_file_name, "rb") as optim_file:
                    opt_state_dict = torch.load(optim_file)
                    try:
                        optimizer.load_state_dict(opt_state_dict)

                    # in case the optimizer param groups aren't the same shape,
                    # merge them
                    except Exception:
                        logging("merging optimizer param groups")
                        opt_state_dict["param_groups"][0]["params"] = [
                            param
                            for param_group in opt_state_dict["param_groups"]
                            for param in param_group["params"]
                        ]
                        opt_state_dict["param_groups"] = [
                            opt_state_dict["param_groups"][0]
                        ]
                        optimizer.load_state_dict(opt_state_dict)
            else:
                logging("Optimizer was not saved. Start from scratch.")

        return optimizer
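
The except branch above flattens every saved param group into a single group so the reload succeeds when the current optimizer was built with one group. A standalone sketch of that reshaping on a hypothetical state dict (integer ids stand in for parameter ids; not the author's checkpoint format):

opt_state_dict = {
    "state": {},
    "param_groups": [
        {"params": [0, 1], "lr": 1e-3},
        {"params": [2, 3], "lr": 1e-3},
    ],
}

# Same merge as in the except branch: gather all params into group 0,
# then keep only that group.
opt_state_dict["param_groups"][0]["params"] = [
    param
    for param_group in opt_state_dict["param_groups"]
    for param in param_group["params"]
]
opt_state_dict["param_groups"] = [opt_state_dict["param_groups"][0]]

assert opt_state_dict["param_groups"][0]["params"] == [0, 1, 2, 3]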
Example #4
    def __getitem__(self, index):
        # An item is a tuple of:
        #   - a transformer_model input, e.g. 60 dates of the time series
        #   - the following date as the expected output
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" {self.data_type} \t item  no.: {index}")
            logging(
                f"{self.__class__.__name__}, "
                f"{inspect.currentframe().f_code.co_name}: "
                f"       x: from {index} to {index + self.global_state.n_model}"
            )
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f"       y: at {index + self.global_state.n_model}")

        return (self.data[index:index + self.global_state.n_model, :],
                self.data[index + self.global_state.n_model, :])
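
A minimal, self-contained sketch of the (x, y) pair returned above, assuming n_model = 3 and a tiny two-feature series (toy values, not the real data):

import torch

data = torch.arange(10, dtype=torch.float32).view(5, 2)  # 5 dates, 2 features
n_model, index = 3, 1                                    # assumed toy values

x = data[index:index + n_model, :]   # dates 1..3 as the model input
y = data[index + n_model, :]         # date 4 as the expected output

assert x.size() == (3, 2) and y.size() == (2,)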
Example #5
    def __init__(self,
                 global_state: GlobalState,
                 data,
                 mode="train",
                 debug=False):
        super(IterableTimeSeries, self).__init__()

        self.global_state = global_state
        self.data_type = mode

        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Creating dataloader for data set: {mode}")

        # In debug mode, only use about 2 epochs of input
        # TODO: refactor to use exactly 2 epochs instead of 700 dates.
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            total_data_set_length = min(global_state.dataset_size,
                                        data.size(0))
        else:
            total_data_set_length = data.size(0)

        # The beginning of the data set is where 'train' starts.
        # The end of the data set is where the last test sample ends.
        # We therefore start at 0 and end at
        # total_data_set_length = n_samples + (n_model + 1) + n_val + n_test
        # (a sample is n_model vectors for X and 1 vector for Y).
        # The final -1 reflects Python's zero-based indexing.
        self.n_samples = total_data_set_length - \
                         (global_state.n_model + 1) - \
                         global_state.n_val - \
                         global_state.n_test - \
                         1

        # Adjust the start of the dataset for training / val / test
        if mode == "train":
            start_index = 0
            end_index = (global_state.n_model + 1) + self.n_samples

        elif mode == "val":
            start_index = self.n_samples
            end_index = (global_state.n_model + 1) + self.n_samples + \
                        global_state.n_val

        elif mode == "test":
            start_index = self.n_samples + global_state.n_val
            end_index = (global_state.n_model + 1) + self.n_samples + \
                        global_state.n_val + \
                        global_state.n_test

        # This is the actual input on which to iterate
        self.data = data[start_index:end_index, :]

        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Dataset {self.data_type} - Start index: {start_index}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Dataset {self.data_type} - End index: {end_index}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Dataset {self.data_type} - data: {self.data.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Dataset {self.data_type} - data set iterator"
                    f" length: {self.data.size()[0]}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" Dataset {self.data_type} - calculated"
                    f" n_samples: {self.n_samples}")

        # d_series is the depth of a series (how many input points per date)
        # n_series is the number of series (how many dates)
        self.n_series, self.d_series = data.size()
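
The slicing above carves one tensor into overlapping train/val/test views. A hedged numeric sketch of the same index arithmetic, with toy sizes standing in for the GlobalState fields:

total_data_set_length = 1000            # assumed
n_model, n_val, n_test = 60, 100, 100   # assumed

n_samples = total_data_set_length - (n_model + 1) - n_val - n_test - 1  # 738

splits = {
    "train": (0, (n_model + 1) + n_samples),
    "val":   (n_samples, (n_model + 1) + n_samples + n_val),
    "test":  (n_samples + n_val, (n_model + 1) + n_samples + n_val + n_test),
}
for mode, (start, end) in splits.items():
    print(f"{mode}: rows {start}..{end - 1} ({end - start} rows)")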
Example #6
    def validation_step(self, batch, batch_nb):
        # DIMS: batch = (x, y)
        # DIMS: x -> (n_batch, n_model, d_model)
        # DIMS: y -> (n_batch, d_model)
        x, y = batch

        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" x = batch[0]: {batch[0].size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y = batch[1]: {batch[1].size()}")

        y_hat = self.forward(x, y)
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y_hat['loss']: {y_hat['loss'].size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y_hat['layer_out']: {y_hat['layer_out'].size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y_hat['memory'][0]: {y_hat['memory'][0].size()}")

        val_loss = self.loss_function(y_hat['layer_out'][:, -1, :], y)
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" loss: {val_loss.size()}")

        val_loss = val_loss.unsqueeze(dim=-1)
        return {"val_loss": val_loss}
Example #7
    def training_step(self,
                      batch: List[torch.Tensor],
                      batch_idx: int,
                      optimizer_idx: int = 1):
        # DIMS: batch = (x, y)
        # DIMS: x -> (n_batch, n_model, d_model)
        # DIMS: y -> (n_batch, d_model)
        x, y = batch

        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" x = batch[0]: {batch[0].size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y = batch[1]: {batch[1].size()}")

        y_hat = self.forward(x, y)
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y_hat['loss']: {y_hat['loss'].size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y_hat['layer_out']: {y_hat['layer_out'].size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" y_hat['memory'][0]: {y_hat['memory'][0].size()}")

        loss = self.loss_function(y_hat['layer_out'][:, -1, :], y)
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" loss: {loss.size()}")

        loss = loss.unsqueeze(dim=-1)
        return {"loss": loss}
Example #8
    def forward(self, input: torch.FloatTensor, output: torch.FloatTensor,
                *mems):
        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"")
            logging(f"")
            logging(
                f"########################################################")
            logging(f"")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" input: {input.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" output: {output.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" mems: {len(mems)}")

        return self.transformer_model(input, output, *mems)
Example #9
    def __init__(self, global_state: GlobalState):
        super(TransformerXL_Trainer, self).__init__()

        self.global_state = global_state

        if self.global_state.debug and (self.__class__.__name__
                                        not in self.global_state.skip_debug):
            logging(f"")
            logging(f"")
            logging(
                f"########################################################"
                f"########################################################")
            logging(
                f"########################################################"
                f"########################################################")
            logging(f"")
            logging(f"    INITIALISING TRANSFORMER XL")
            logging(f"")
            logging(
                f"########################################################"
                f"########################################################")
            logging(
                f"########################################################"
                f"########################################################")
            logging(f"")

        self.transformer_model = Transformer_XL(
            n_layer=global_state.n_layer,
            d_hidden=global_state.d_hidden,
            d_pos_enc=global_state.d_pos_enc,
            n_head=global_state.n_head,
            d_head=global_state.d_head,
            d_FF_inner=global_state.d_FF_inner,
            d_model=global_state.d_model,
            dropout=global_state.dropout,
            dropout_attn=global_state.dropout_attn,
            n_model=global_state.n_model,
            n_mems=global_state.n_mems,
            debug=global_state.debug,
            skip_debug=global_state.skip_debug)

        self.loss_function = nn.MSELoss()
    def forward(
        self,
        data: torch.FloatTensor,
        target: torch.FloatTensor,  # (n_model, bs)
        memory: Optional[List[torch.FloatTensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        # DIMS: data -> (n_batch, n_model, d_model)
        # DIMS: target -> (n_model, d_model)

        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" data: {data.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" target: {target.size()}")

        if memory is None:
            memory: List[torch.FloatTensor] = self.init_memory(data.device)

        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" memory: {len(memory)}")

        assert len(memory) == len(self.layers) + 1

        n_batch, n_sequence, d_sequence = data.size()
        prev_seq = memory[0].size(0)

        # Construct attention mask
        dec_attn_mask = torch.triu(
            torch.ones((n_sequence, n_sequence + prev_seq)),
            diagonal=1 + prev_seq,
        ).bool()[..., None].to(data.device)

        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" dec_attn_mask: {dec_attn_mask.size()}")

        current_segment = self.dropout(data)

        pos_idxs = torch.arange(n_sequence + prev_seq - 1,
                                -1,
                                -1.0,
                                dtype=torch.float).to(current_segment.device)
        pos_embs = self.dropout(self.pos_enc(pos_idxs))

        # Main part of forward pass
        hidden_states = [current_segment]
        layer_out = current_segment
        for mem, layer in zip(memory, self.layers):
            layer_out = layer(layer_out,
                              pos_embs,
                              self.u,
                              self.v,
                              mask=dec_attn_mask,
                              mems=mem)
            hidden_states.append(layer_out)

        layer_out = self.dropout(layer_out)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" layer_out: {layer_out.size()}")

        # loss = self.loss_fn(layer_out.view(-1, layer_out.size(-1)),
        #                     target.view(-1))
        loss = self.loss_fn(layer_out[:, -1, :], target)

        # Update memory
        # Ensure the memory is treated as a constant
        # and we do not back propagate through them
        new_memory = self.update_memory(memory, hidden_states)

        return {"loss": loss, "layer_out": layer_out, "memory": new_memory}
    def forward(
        self,
        segment: torch.FloatTensor,
        pos_encs: torch.FloatTensor,
        memories: torch.FloatTensor,
        u: torch.FloatTensor,
        v: torch.FloatTensor,
        mask: Optional[torch.FloatTensor] = None,
    ):
        """
        pos_encs: position encodings is separate to handle relative positions
        DIMS: segment -> (n_batch, n_model, d_model)
        DIMS: pos_embs -> (n_model + n_mems, self.d_input)
        DIMS: output ->  (n_model, self.d_input)
        DIMS: u ->  (n_head, d_head)
        """

        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" segment: {segment.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" pos_encs: {pos_encs.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" memories: {memories.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" u: {u.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" v: {v.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" mask: {mask.size()}")

        # length of current segment
        n_batch, n_model, d_input = segment.shape

        # length of memory available
        n_current_mems = memories.shape[0]

        n_head, d_head = self.n_head, self.d_head

        # DIMS: memory_cat_input -> (n_current_mems + n_model, d_input)
        memory_cat_input = torch.cat([memories, segment], dim=1)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" memory_cat_input: {memory_cat_input.size()}")

        # DIMS: segment -> (d_batch, n_model, d_input)
        # DIMS: self.linear_q -> (d_input, n_head * d_head)
        # DIMS: queries -> (d_batch, n_model, n_head * d_head)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" linear_q: in {self.linear_q.in_features} x "
                    f"out {self.linear_q.out_features}")

        queries = self.linear_q(segment)

        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" queries: {queries.size()}")

        # DIMS: memory_cat_input -> (n_model + n_current_mems, d_input)
        # DIMS: self.linear_kv -> (d_input, d_head * n_head * 2)
        # DIMS: keys -> (d_batch, n_model + n_current_mems, d_head * n_head)
        # DIMS: values -> (d_batch, n_model + n_current_mems, d_head * n_head)
        keys, values = torch.chunk(self.linear_kv(memory_cat_input), 2, dim=-1)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" keys: {keys.size()}")
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" values: {values.size()}")

        # DIMS: queries -> (d_batch, n_model, n_head * d_head)
        # DIMS: u ->  (n_head, d_head)
        # DIMS: content_attn -> (n_model, n_model + n_current_mems, n_head)
        content_attn = torch.einsum(
            "bihd,bjhd->bijh",
            ((queries.view(n_batch, n_model, n_head, d_head) + u),
             keys.view(n_batch, n_model + n_current_mems, n_head, d_head)))
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" content_attn: {content_attn.size()}")

        # position-based attention term ((b) + (d) in the paper)
        # this attention is solely based on the position of the key/values
        # (i.e. it does not take the content of the key/values into account)

        # DIMS: pos_encs -> (n_model + n_current_mems, d_pos_enc)
        # DIMS: self.linear_p -> (d_pos_enc, d_head * n_head)
        # DIMS: positions -> (n_model + n_current_mems, d_head * n_head)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" linear_p: in {self.linear_p.in_features} x "
                    f"out {self.linear_p.out_features}")

        positions = self.linear_p(pos_encs)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" positions: {positions.size()}")

        # DIMS: position_attn -> (n_model, n_model + n_current_mems, n_head)
        position_attn = torch.einsum(
            "bihd,jhd->bijh",
            ((queries.view(n_batch, n_model, n_head, d_head) + v),
             positions.view(n_model + n_current_mems, n_head, d_head)))
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" position_attn: {position_attn.size()}")

        # Compute positional attention efficiently
        # DIMS: position_attn -> (n_model, n_model + n_current_mems, n_head)
        position_attn = self._rel_shift(position_attn)

        # the attention is the sum of content-based and position-based attention
        # DIMS: attn -> (n_batch, n_model, n_model + n_current_mems, n_head)
        attn = content_attn + position_attn
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" attn: {attn.size()}")

        if mask is not None and mask.any().item():
            padded_mask = mask[None, :, :, :]
            if self.debug and (self.__class__.__name__ not in self.skip_debug):
                logging(f"{self.__class__.__name__}, "
                        f"{inspect.currentframe().f_code.co_name}: "
                        f" padded_mask: {padded_mask.size()}")

            attn = attn.masked_fill(padded_mask, -float('inf'))
            if self.debug and (self.__class__.__name__ not in self.skip_debug):
                logging(f"{self.__class__.__name__}, "
                        f"{inspect.currentframe().f_code.co_name}: "
                        f" attn: {attn.size()}")

        # rescale to prevent values from exploding;
        # normalize across the key/value sequence dimension (dim j)
        # TODO change softmax
        attn = torch.softmax(attn * self.scale, dim=2)
        attn = self.dropout_attn(attn)

        # DIMS: attn -> (n_model, n_model + n_current_mems, n_head)
        # DIMS: values -> (n_model + n_current_mems, d_head * n_head)
        # DIMS: values.view -> (n_model + n_current_mems, n_head, d_head)
        # i: n_model
        # j: n_model + n_current_mems
        # h: n_head
        # d: d_head
        # DIMS: einsum -> (n_model, n_head, d_head)
        # DIMS: attn_weighted_values -> (n_model, n_head* d_head)
        attn_weighted_values = torch.einsum(
            "bijh,bjhd->bihd",
            (attn,
             values.view(n_batch, n_model + n_current_mems, n_head, d_head),)) \
            .contiguous() \
            .view(n_batch, n_model, n_head * d_head)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" attn_weighted_values: {attn_weighted_values.size()}")

        # Project back to input dimension and add residual connection
        # DIMS: self.layer_out() -> (d_head * n_head, d_output)
        # DIMS: attn_weighted_values -> (n_model, n_head * d_head)
        # DIMS: self.dropout(...) -> (n_model, d_output)
        # DIMS: segment -> (n_model, self.d_input)
        output = segment + self.dropout(self.layer_out(attn_weighted_values))
        output = self.norm_out(output)
        if self.debug and (self.__class__.__name__ not in self.skip_debug):
            logging(f"{self.__class__.__name__}, "
                    f"{inspect.currentframe().f_code.co_name}: "
                    f" output: {output.size()}")

        return output
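
The two einsums above form per-head dot products between queries and keys (content term) and between queries and position encodings (position term). A toy shape check under assumed small sizes:

import torch

n_batch, n_model, n_mems = 2, 4, 3   # assumed toy sizes
n_head, d_head = 2, 5                # assumed toy sizes

queries = torch.randn(n_batch, n_model, n_head, d_head)
keys = torch.randn(n_batch, n_model + n_mems, n_head, d_head)
positions = torch.randn(n_model + n_mems, n_head, d_head)

content_attn = torch.einsum("bihd,bjhd->bijh", queries, keys)
position_attn = torch.einsum("bihd,jhd->bijh", queries, positions)

# Both terms line up as (n_batch, n_model, n_model + n_mems, n_head)
assert content_attn.size() == (n_batch, n_model, n_model + n_mems, n_head)
assert position_attn.size() == content_attn.size()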