Example #1
class TensorboardLogger(BaseLogger):
    def open(self, root_dir: str):
        from pathlib import Path
        log_dir = Path(root_dir, 'tb_logs')
        self.logger = SummaryWriter(str(log_dir))

    def close(self):
        self.logger.close()

    def log_scalar(self, name: str, value: float, step: Optional[int] = None):
        self.logger.add_scalar(name, value, step)

    def log_scalars(self,
                    name: str,
                    scalar_dict: Dict[str, float],
                    step: Optional[int] = None):
        self.logger.add_scalars(name, scalar_dict, step)

    def log_model(self, model: ITrainableModel, device: str):
        inputs = model.example_inputs()
        if inputs is None:
            pass  # no example inputs provided; add_graph will receive input_to_model=None
        elif isinstance(inputs, torch.Tensor):
            inputs = inputs.to(device)
        elif isinstance(inputs, (list, tuple)):
            inputs = [x.to(device) for x in inputs]
        elif isinstance(inputs, dict):
            inputs = {k: v.to(device) for k, v in inputs.items()}
        else:
            raise RuntimeError('Unsupported example_inputs')
        self.logger.add_graph(model, input_to_model=inputs)

    def log_info(self, info: str):
        pass
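A minimal usage sketch for the logger above, assuming SummaryWriter comes from torch.utils.tensorboard and that BaseLogger and ITrainableModel are defined elsewhere in the project; the root directory and scalar tags below are made up for illustration.

# Hypothetical usage of TensorboardLogger; paths and tag names are illustrative only.
logger = TensorboardLogger()
logger.open('/tmp/experiments')  # events are written under /tmp/experiments/tb_logs
for step in range(100):
    logger.log_scalar('train/loss', 1.0 / (step + 1), step)
    logger.log_scalars('train/metrics', {'acc': 0.5, 'f1': 0.4}, step)
logger.close()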
Example #2
class SeqClassificationModel(Model):
    """
    Question answering model where answers are sentences
    """
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        use_sep: bool = True,
        with_crf: bool = False,
        self_attn: Seq2SeqEncoder = None,
        bert_dropout: float = 0.1,
        sci_sum: bool = False,
        additional_feature_size: int = 0,
    ) -> None:
        super(SeqClassificationModel, self).__init__(vocab)

        self.track_embedding_list = []
        self.track_embedding = {}
        self.text_field_embedder = text_field_embedder
        self.vocab = vocab
        self.use_sep = use_sep
        self.with_crf = with_crf
        self.sci_sum = sci_sum
        self.self_attn = self_attn
        self.additional_feature_size = additional_feature_size

        self.dropout = torch.nn.Dropout(p=bert_dropout)

        # define loss
        if self.sci_sum:
            self.loss = torch.nn.MSELoss(
                reduction='none')  # labels are rouge scores
            self.labels_are_scores = True
            self.num_labels = 1
        else:
            self.loss = torch.nn.CrossEntropyLoss(ignore_index=-1,
                                                  reduction='none')
            self.labels_are_scores = False
            self.num_labels = self.vocab.get_vocab_size(namespace='labels')
            # define accuracy metrics
            self.label_accuracy = CategoricalAccuracy()
            self.label_f1_metrics = {}

            # define F1 metrics per label
            for label_index in range(self.num_labels):
                label_name = self.vocab.get_token_from_index(
                    namespace='labels', index=label_index)
                self.label_f1_metrics[label_name] = F1Measure(label_index)

        encoded_sentence_dim = text_field_embedder._token_embedders[
            'bert'].output_dim

        ff_in_dim = (encoded_sentence_dim
                     if self.use_sep else self_attn.get_output_dim())
        ff_in_dim += self.additional_feature_size

        self.time_distributed_aggregate_feedforward = TimeDistributed(
            Linear(ff_in_dim, self.num_labels))

        if self.with_crf:
            self.crf = ConditionalRandomField(
                self.num_labels,
                constraints=None,
                include_start_end_transitions=True)
        self.track_embedding["init_info"] = {
            "ff_in_dim": ff_in_dim,
            "encoded_sentence_dim": encoded_senetence_dim,
            "sci_sum": self.sci_sum,
            "use_sep": self.use_sep,
            "with_crf": self.with_crf,
            "additional_feature_size": self.additional_feature_size
        }
        self.t_board_writer = SummaryWriter()
        self.t_board_writer.add_graph(self)

    def forward(
        self,  # type: ignore
        sentences: torch.LongTensor,
        labels: torch.IntTensor = None,
        confidences: torch.Tensor = None,
        additional_features: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        """
        Parameters
        ----------
        TODO: add description

        Returns
        -------
        An output dictionary consisting of:
        loss : torch.FloatTensor, optional
            A scalar loss to be optimised.
        """
        # ===========================================================================================================
        # Layer 1: For each sentence, participant pair: create a Glove embedding for each token
        # Input: sentences
        # Output: embedded_sentences
        print(sentences)
        # Convert the raw token-id tensors to plain lists so they can be tracked/serialized.
        sentences_conv = {}
        for key, val in sentences.items():
            sentences_conv[key] = val.cpu().data.numpy().tolist()
        self.track_embedding["Transformation_0"] = {
            "sentences": sentences_conv
        }
        # embedded_sentences: batch_size, num_sentences, sentence_length, embedding_size
        embedded_sentences = self.text_field_embedder(sentences)
        self.track_embedding["Transformation_1"] = {
            "size": list(embedded_sentences.size()),
            "dim": embedded_sentences.dim()
        }

        # Kacper: Basically a padding mask for bert
        mask = get_text_field_mask(sentences, num_wrapping_dims=1).float()
        batch_size, num_sentences, _, _ = list(embedded_sentences.size())

        if self.use_sep:
            # The following code collects the vectors of the SEP tokens from all the examples in the batch
            # and arranges them in one list. It does the same for the labels and confidences.
            # (A standalone sketch of this SEP-masking step follows this example.)
            # TODO: replace 103 with '[SEP]'
            # Kacper: This is an important step where we get SEP tokens to later do sentence classification
            # Kacper: We take a location of SEP tokens from the sentences to get a mask
            sentences_mask = sentences[
                'bert'] == 103  # mask for all the SEP tokens in the batch
            # Kacper: We use this mask to get the respective embeddings from the output layer of bert
            embedded_sentences = embedded_sentences[
                sentences_mask]  # given batch_size x num_sentences_per_example x sent_len x vector_len
            # returns num_sentences_per_batch x vector_len
            self.track_embedding["Transformation_2"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: I don't get why it became 2 instead of 4. What is the difference between size() and dim()?
            assert embedded_sentences.dim() == 2
            num_sentences = embedded_sentences.shape[0]
            # Kacper: comment below is vague
            # Kacper: I think we batch in one array because we just need to compute a mean loss from all of them
            # for the rest of the code in this model to work, think of the data we have as one example
            # with so many sentences and a batch of size 1
            batch_size = 1
            embedded_sentences = embedded_sentences.unsqueeze(
                dim=0)  # Kacper: We batch all sentences in one array
            self.track_embedding["Transformation_3"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: Dropout layer is between filtered embeddings and linear layer
            embedded_sentences = self.dropout(embedded_sentences)
            self.track_embedding["Transformation_4"] = {
                "size": list(embedded_sentences.size()),
                "dim": embedded_sentences.dim()
            }
            # Kacper: we provide the labels for training (for each sentence)
            if labels is not None:
                if self.labels_are_scores:
                    labels_mask = labels != 0.0  # mask for all the labels in the batch (no padding)
                else:
                    labels_mask = labels != -1  # mask for all the labels in the batch (no padding)

                labels = labels[
                    labels_mask]  # given batch_size x num_sentences_per_example return num_sentences_per_batch
                assert labels.dim() == 1
                if confidences is not None:
                    confidences = confidences[labels_mask]
                    assert confidences.dim() == 1
                if additional_features is not None:
                    additional_features = additional_features[labels_mask]
                    assert additional_features.dim() == 2

                num_labels = labels.shape[0]
                # Kacper: this might be useful to consider in my code as well
                if num_labels != num_sentences:  # bert truncates long sentences, so some of the SEP tokens might be gone
                    assert num_labels > num_sentences  # so `num_labels` must be strictly greater than `num_sentences`
                    logger.warning(
                        f'Found {num_labels} labels but {num_sentences} sentences'
                    )
                    labels = labels[:
                                    num_sentences]  # Ignore some labels. This is ok for training but bad for testing.
                    # We are ignoring this problem for now.
                    # TODO: fix, at least for testing

                # do the same for `confidences`
                if confidences is not None:
                    num_confidences = confidences.shape[0]
                    if num_confidences != num_sentences:
                        assert num_confidences > num_sentences
                        confidences = confidences[:num_sentences]

                # and for `additional_features`
                if additional_features is not None:
                    num_additional_features = additional_features.shape[0]
                    if num_additional_features != num_sentences:
                        assert num_additional_features > num_sentences
                        additional_features = additional_features[:
                                                                  num_sentences]

                # similar to `embedded_sentences`, add an additional dimension that corresponds to batch_size=1
                labels = labels.unsqueeze(dim=0)
                if confidences is not None:
                    confidences = confidences.unsqueeze(dim=0)
                if additional_features is not None:
                    additional_features = additional_features.unsqueeze(dim=0)
        else:
            # ['CLS'] token
            # Kacper: this shouldn't be the case for our project
            embedded_sentences = embedded_sentences[:, :, 0, :]
            embedded_sentences = self.dropout(embedded_sentences)
            batch_size, num_sentences, _ = list(embedded_sentences.size())
            sent_mask = (mask.sum(dim=2) != 0)
            embedded_sentences = self.self_attn(embedded_sentences, sent_mask)

        if additional_features is not None:
            embedded_sentences = torch.cat(
                (embedded_sentences, additional_features), dim=-1)

        # Kacper: we unwrap the time dimension of a tensor into the 1st dimension (batch),
        # Kacper: apply a linear layer and wrap the time dimension back
        # Kacper: I would suspect it is happening only for embeddings related to the [SEP] tokens
        label_logits = self.time_distributed_aggregate_feedforward(
            embedded_sentences)
        # label_logits: batch_size, num_sentences, num_labels
        self.track_embedding["logits"] = {
            "size": list(label_logits.size()),
            "dim": label_logits.dim()
        }
        #print(self.track_embedding)
        self.track_embedding_list.append(deepcopy(self.track_embedding))
        with open(path_json, 'w') as json_out:
            json.dump(self.track_embedding_list, json_out)

        if self.labels_are_scores:
            label_probs = label_logits
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=-1)

        # Create output dictionary for the trainer
        # Compute loss and epoch metrics
        output_dict = {"action_probs": label_probs}

        # =====================================================================

        if self.with_crf:
            # Layer 4 = CRF layer across labels of sentences in an abstract
            mask_sentences = (labels != -1)
            best_paths = self.crf.viterbi_tags(label_logits, mask_sentences)
            # Just get the tags and ignore the scores.
            predicted_labels = [x for x, y in best_paths]
            # print(f"len(predicted_labels):{len(predicted_labels)}, (predicted_labels):{predicted_labels}")

            label_loss = 0.0
        if labels is not None:
            # Compute cross entropy loss
            # Kacper: reshape the logits to (batch_size * num_sentences, num_labels) with view()
            flattened_logits = label_logits.view((batch_size * num_sentences),
                                                 self.num_labels)
            # Make the labels contiguous in memory and flatten them to one dimension
            flattened_gold = labels.contiguous().view(
                -1)  # Kacper: True labels

            if not self.with_crf:
                # Kacper: We are only interested in this part of the code since we don't use crf
                # Kacper: Get the loss (MSE if sci_sum is True, otherwise cross-entropy)
                label_loss = self.loss(flattened_logits.squeeze(),
                                       flattened_gold)
                if confidences is not None:
                    label_loss = label_loss * confidences.type_as(
                        label_loss).view(-1)
                label_loss = label_loss.mean()  # Kacper: Get a mean loss
                # Kacper: Get probabilities from the logits
                flattened_probs = torch.softmax(flattened_logits, dim=-1)
            else:
                # Kacper: We are not interested in this if statement branch (for our project)
                clamped_labels = torch.clamp(labels, min=0)
                log_likelihood = self.crf(label_logits, clamped_labels,
                                          mask_sentences)
                label_loss = -log_likelihood
                # compute categorical accuracy
                crf_label_probs = label_logits * 0.
                for i, instance_labels in enumerate(predicted_labels):
                    for j, label_id in enumerate(instance_labels):
                        crf_label_probs[i, j, label_id] = 1
                flattened_probs = crf_label_probs.view(
                    (batch_size * num_sentences), self.num_labels)

            if not self.labels_are_scores:
                # Kacper: this will be the case for us as well because labels are numerical for the Pubmed data
                evaluation_mask = (flattened_gold != -1)
                # Kacper: CategoricalAccuracy is computed in this case
                self.label_accuracy(flattened_probs.float().contiguous(),
                                    flattened_gold.squeeze(-1),
                                    mask=evaluation_mask)

                # compute F1 per label
                for label_index in range(self.num_labels):
                    label_name = self.vocab.get_token_from_index(
                        namespace='labels', index=label_index)
                    metric = self.label_f1_metrics[label_name]
                    metric(flattened_probs,
                           flattened_gold,
                           mask=evaluation_mask)

        if labels is not None:
            output_dict["loss"] = label_loss
        output_dict['action_logits'] = label_logits
        return output_dict

    def get_metrics(self, reset: bool = False):
        # Kacper: this function has to be implemented due to the API requirements of AllenNLP
        # Kacper: so it can be run automatically with a config file
        metric_dict = {}

        if not self.labels_are_scores:
            type_accuracy = self.label_accuracy.get_metric(reset)
            metric_dict['acc'] = type_accuracy

            average_F1 = 0.0
            for name, metric in self.label_f1_metrics.items():
                metric_val = metric.get_metric(reset)
                metric_dict[name + 'F'] = metric_val[2]
                average_F1 += metric_val[2]

            average_F1 /= len(self.label_f1_metrics)
            metric_dict['avgF'] = average_F1

        return metric_dict
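The [SEP]-selection step in forward() above reduces a (batch_size, num_sentences, sent_len, hidden) tensor to one row per sentence by boolean-masking on the positions of the [SEP] token id (103), which is why the result has dim() == 2. A minimal, self-contained sketch of that step with made-up shapes and token ids:

import torch

# Toy stand-in for the BERT output: 2 examples, 3 sentences each, 5 tokens, hidden size 4.
embedded = torch.randn(2, 3, 5, 4)
token_ids = torch.randint(0, 30000, (2, 3, 5))
token_ids[:, :, -1] = 103                 # pretend the last token of every sentence is [SEP]

sep_mask = token_ids == 103               # (2, 3, 5) boolean mask, True at [SEP] positions
sep_embeddings = embedded[sep_mask]       # (num_sep_tokens, 4) == (6, 4); dim() is now 2
assert sep_embeddings.dim() == 2

# As in the model, treat all selected sentences as a single "batch" of size 1.
sep_embeddings = sep_embeddings.unsqueeze(0)   # (1, 6, 4)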
Example #3
def train(opts: Settings):
    # === INSTANTIATE ENVIRONMENT ===
    # gym.make() but with imports configured as specified in arg.
    _gym_make = partial(gym_make, opts.imports)
    subproc_gym_make = subproc(_gym_make)

    # If `opts.subproc==True`, invoke gym.make() in a subprocess,
    # and treat the resultant instance as a `gym.Env`.
    make_env = subproc_gym_make if opts.subproc else _gym_make

    def get_env(index: int):
        env = make_env(opts.env_id)
        env.seed(index)
        env.reset()
        return env

    env = MultiEnv(get_env, opts.num_envs)
    entry_type = [
        ('state', env.observation_space.dtype, env.observation_space.shape),
        ('action', env.action_space.dtype, env.action_space.shape),
        ('reward', np.float32, (1, )),
        # ('state1', env.observation_space.dtype, env.observation_space.shape),
        ('done', np.bool_, (1, )),  # np.bool_ instead of the removed np.bool alias
        ('value', np.float32, (1, )),
        ('log_prob', np.float32, env.action_space.shape)
    ]

    # === NORMALIZERS FOR INPUTS ===
    reward_normalizer = ExponentialMovingGaussian(
        alpha=opts.reward_normalizer_alpha)
    state_normalizer = ExponentialMovingGaussian(
        alpha=opts.state_normalizer_alpha)

    # === INSTANTIATE MEMORY ===
    memory = ContiguousRingBuffer(capacity=opts.update_steps,
                                  dims=(opts.num_envs, ),
                                  dtype=entry_type)

    # === INSTANTIATE POLICY ===
    # FIXME(ycho): Instead of assuming 1D box spaces,
    # explicitly wrap envs with flatten()...
    device = th.device(opts.device)
    policy = AC(env.observation_space.shape[0], env.action_space.shape[0],
                opts.ac).to(device)

    # === INSTANTIATE AGENT ===
    ppo = PPO(policy, memory, device, opts.ppo)

    # === TRAIN ===
    states = env.reset()
    dones = np.full((opts.num_envs, 1), False, dtype=np.bool_)
    returns = np.zeros(opts.num_envs, dtype=np.float32)

    # === LOGGER ===
    # TODO(ycho): Configure logger
    writer = SummaryWriter()
    writer.add_graph(policy, th.as_tensor(states).float().to(device))

    # === CALLBACKS ===
    save_cb = SaveCallback(
        opts.save_steps, opts.ckpt_path, lambda: {
            'settings': opts,
            'state_dict': policy.state_dict(),
            'reward_normalizer': reward_normalizer.params(),
            'state_normalizer': state_normalizer.params()
        })

    # === VARIABLES FOR DEBUGGING / LOG TRACKING ===
    reset_count = 0
    start_time = time.time()

    # === START TRAINING ===
    step = 0
    while step < opts.max_steps:
        # Reset any env that has reached termination.
        # FIXME(ycho): assumes isinstance(env, MultiEnv), of course.
        for i in range(opts.num_envs):
            if not dones[i]:
                continue
            states[i][:] = env.envs[i].reset()
            returns[i] = 0.0
            reset_count += 1

        # NOTE(ycho): Workaround for the current limitation of `MultiEnv`.
        # action = [env.action_space.sample() for _ in range(opts.num_envs)]
        # sanitize `states` arg.
        states = np.asarray(states).astype(np.float32)

        # Add states to stats for normalization.
        for s in states:
            state_normalizer.add(s)

        # Normalize states in-place.
        states = state_normalizer.normalize(states)
        states = np.clip(states, -10.0, 10.0)  # clip to +-10 stddev

        with th.no_grad():
            action, value, log_prob = ppo.act(states, True)

        # NOTE(ycho): Clip action within valid domain...
        clipped_action = np.clip(action, env.action_space.low,
                                 env.action_space.high)

        # Step according to above action.
        out = env.step(clipped_action)

        # Format entry.
        nxt_states, rewards, dones, _ = out

        # Add rewards to stats for normalization.
        # returns[np.asarray(dones).reshape(-1).astype(np.bool)] = 0.0
        returns = returns * opts.gae.gamma + np.reshape(rewards, -1)
        # NOTE(ycho): collect stats on `returns` instead of `rewards`.
        # for r in rewards:
        #    reward_normalizer.add(r)
        for r in returns:
            reward_normalizer.add(r)

        # Train if buffer full ...
        if memory.is_full:
            writer.add_scalar('reward_mean',
                              reward_normalizer.mean,
                              global_step=step)
            writer.add_scalar('reward_var',
                              reward_normalizer.var,
                              global_step=step)
            writer.add_scalar('log_std',
                              policy.log_std.detach().cpu().numpy()[0],
                              global_step=step)
            writer.add_scalar('fps',
                              step / (time.time() - start_time),
                              global_step=step)

            # NOTE(ycho): Don't rely on printed reward stats for tracking
            # training progress ... use tensorboard instead.
            print('== step {} =='.format(step))
            # Log reward before overwriting with normalized values.
            print('rew = mean {} min {} max {} std {}'.format(
                memory['reward'].mean(), memory['reward'].min(),
                memory['reward'].max(), memory['reward'].std()))
            # print('rm {} rv {}'.format(reward_normalizer.mean,
            #                           reward_normalizer.var))

            # NOTE(ycho): States have already been normalized,
            # since those states were utilized as input for PPO action.
            # After that, the normalized states were inserted in memory.
            # memory['state'] = state_normalizer.normalize(memory['state'])

            # NOTE(ycho): I think it's fine to delay reward normalization to this point.
            # memory['reward'] = reward_normalizer.normalize(memory['reward'])
            # NOTE(ycho): maybe the proper thing to do is:
            # memory['reward'] = (memory['reward'] - reward_normalizer.mean) / np.sqrt(return_normalizer.var)
            memory['reward'] /= np.sqrt(reward_normalizer.var)
            memory['reward'] = np.clip(memory['reward'], -10.0, 10.0)

            # Create training data slices from memory ...
            dones = np.asarray(dones).reshape(opts.num_envs, 1)
            advs, rets = gae(memory, value, dones, opts.gae)
            # print('std = {}'.format(ppo.policy.log_std.exp()))

            ucount = 0
            info = None
            for _ in range(opts.num_epochs):
                for i in range(0, len(memory), opts.batch_size):
                    # Prepare current minibatch dataset ...
                    exp = memory[i:i + opts.batch_size]
                    act = exp['action']
                    obs = exp['state']
                    old_lp = exp['log_prob']
                    # old_v = exp['value'] # NOTE(ycho): unused
                    adv = advs[i:i + opts.batch_size]
                    ret = rets[i:i + opts.batch_size]

                    # Evaluate what had been done ...
                    # NOTE(ycho): wouldn't new_v == old_v
                    # and new_lp == old_lp for the very first one in the batch??
                    # hmm ....
                    new_v, new_lp, entropy = ppo.evaluate(
                        obs.copy(), act.copy())

                    info_i = {}
                    loss = ppo.compute_loss(obs.copy(), act.copy(),
                                            old_lp.copy(), new_v, new_lp,
                                            entropy, adv, ret, info_i)

                    # NOTE(ycho): Below, only required for logging
                    if True:
                        with th.no_grad():
                            if info is None:
                                info = info_i
                            else:
                                for k in info.keys():
                                    info[k] += info_i[k]
                        ucount += 1

                    # Optimization step
                    ppo.optimizer.zero_grad()
                    loss.backward()
                    # Clip grad norm
                    th.nn.utils.clip_grad_norm_(ppo.policy.parameters(),
                                                opts.ppo.max_grad_norm)
                    ppo.optimizer.step()

            for k, v in info.items():
                writer.add_scalar(k,
                                  v.detach().cpu().numpy() / ucount,
                                  global_step=step)

            # Empty the memory !
            memory.reset()

        # Append to memory.
        entry = list(
            zip(*(
                states,
                action,
                rewards,
                # nxt_states,
                dones,
                value,
                log_prob)))
        memory.append(entry)

        # Cache `states`, update steps and continue.
        states = nxt_states
        step += opts.num_envs

        save_cb.on_step(step)

    writer.close()

    # Save ...
    th.save(
        {
            'settings': opts,
            'state_dict': policy.state_dict(),
            'reward_normalizer': reward_normalizer.params(),
            'state_normalizer': state_normalizer.params()
        }, opts.model_path)
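The gae(memory, value, dones, opts.gae) call above is not shown in this example; below is a minimal sketch of generalized advantage estimation for a single environment's trajectory, assuming opts.gae carries a discount gamma and a GAE lambda. The function name and signature here are illustrative, not the project's API.

import numpy as np

def gae_sketch(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """rewards, values, dones: arrays of shape [T]; last_value bootstraps V(s_T)."""
    T = len(rewards)
    advantages = np.zeros(T, dtype=np.float32)
    acc = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        next_nonterminal = 1.0 - float(dones[t])
        # TD residual: r_t + gamma * V(s_{t+1}) - V(s_t), zeroed past episode ends.
        delta = rewards[t] + gamma * next_value * next_nonterminal - values[t]
        acc = delta + gamma * lam * next_nonterminal * acc
        advantages[t] = acc
    returns = advantages + np.asarray(values, dtype=np.float32)
    return advantages, returns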
Example #4
def train(load_path=None):
    # Define exp name
    exp_name = input("Experience Name (You may leave this blank): ")
    exp_name_ = datetime.now().strftime("%m-%d-%H-%M") + (
        '-' + exp_name) if exp_name else ''

    # Prepare status save location
    if not os.path.exists(save_base_path):
        os.mkdir(save_base_path)

    if not os.path.exists(os.path.join(save_base_path, identifier)):
        os.mkdir(os.path.join(save_base_path, identifier))

    save_path = os.path.join(save_base_path, identifier, exp_name_)
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    # Save constants.py as json file
    with open(os.path.join(save_path, 'config.json'), 'w') as f:
        json.dump(C, f, default=lambda x: str(x), indent=True)

    # Load dipole kernel
    d = load_dipole().to(device)

    # Prepare training and validation datasets
    train_dataset = QSMDataset('train')
    b_mean = train_dataset.X_mean
    b_std = train_dataset.X_std
    y_mean = train_dataset.Y_mean
    y_std = train_dataset.Y_std

    train_loader = DataLoader(train_dataset,
                              num_workers=4,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(QSMDataset('val'), batch_size=1, shuffle=False)

    # Initialize training elements
    if load_path:
        checkpoint = torch.load(load_path, map_location='cpu')
        start_epoch = checkpoint.get('epoch')
        network = checkpoint.get('model')
        optimizer = checkpoint.get('optimizer')
        del checkpoint

        network = network.to(device)
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
    else:
        start_epoch = 0
        network = QSMNet().to(device)
        optimizer = optim.Adam(network.parameters(), lr=learning_rate)

    # Initialize tensorboard logging
    writer = SummaryWriter(comment='_' + identifier + '_' + exp_name)
    writer.add_graph(network, torch.zeros(1, 1, 64, 64, 64, device=device))

    # Announce training
    print('{} training network "{}" for {} epoches on {}{}.'.format(
        'Resume' if load_path else 'Begin', identifier, train_epochs,
        device.type,
        ':' + str(device.index) if device.index is not None else ''))

    # Begin training
    step = start_epoch * len(train_loader)  # Global step count for tensorboard
    for epoch in range(start_epoch, start_epoch + train_epochs):

        network.train()
        loss_history = []

        if epoch < one_to_two - 1:
            print("Loss gear phase one")
            total_loss.phase = 'one'
        else:
            print("Loss gear phase two")
            total_loss.phase = 'two'

        for batch, (b, y, m) in enumerate(train_loader):
            b, y, m = b.to(device), y.to(device), m.to(device)

            optimizer.zero_grad()

            chi = network(b)

            loss, loss_parts = total_loss(chi, y, b, d, m, b_mean, b_std,
                                          y_mean, y_std)
            loss.backward()
            step += 1
            optimizer.step()

            loss_cpu = loss.detach().cpu().item()
            loss_history.append(loss_cpu)

            torch.cuda.empty_cache()

            with torch.no_grad():
                # Track loss with tensorboard
                writer.add_scalar('Loss/Total loss', loss_cpu, step)
                for loss_name, loss_value in loss_parts.items():
                    writer.add_scalar('Loss/' + loss_name, loss_value, step)

                # Print step status
                if (batch + 1) % print_every == 0:
                    train_log = 'Epoch {:2d}/{:2d}\tLoss: {:.6f}\tTrain: [{}/{} ({:.0f}%)]        '.format(
                        epoch + 1,
                        start_epoch + train_epochs, loss_cpu, batch + 1,
                        len(train_loader), 100. * batch / len(train_loader))
                    print(train_log, end='\r')
                    sys.stdout.flush()

        with torch.no_grad():
            # Print epoch status
            last_epoch = epoch + 1
            average_loss = sum(loss_history) / len(loss_history)
            epoch_end_log = "Epoch {:02d} completed, Average Loss is {:.6f}.            ".format(
                last_epoch, average_loss)
            print(epoch_end_log)

            # Save training status
            if last_epoch % save_every == 0:
                ckpt_save_path = os.path.join(
                    save_path,
                    '{}-epoch{:02d}.ckpt'.format(identifier, last_epoch))
                torch.save(
                    {
                        'epoch': last_epoch,
                        'model': network,
                        'optimizer': optimizer,
                    }, ckpt_save_path)
                print("Saved Model on epoch {}.".format(last_epoch))

        # Evaluate model
        #evaluate_model('validation', network, val_loader, last_epoch, save_path, writer)
        challenge_evaluation(exp_name_, network, writer, last_epoch)

    writer.close()
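The checkpoints above pickle the whole network and optimizer objects, which requires the exact class definitions to be importable at load time. A state_dict-based variant (a sketch of an alternative, not what this script does) is more robust to code changes; QSMNet, learning_rate, device and optim are reused from the example above.

# Save only tensors and plain values.
torch.save(
    {
        'epoch': last_epoch,
        'model_state': network.state_dict(),
        'optimizer_state': optimizer.state_dict(),
    }, ckpt_save_path)

# Resume.
checkpoint = torch.load(ckpt_save_path, map_location='cpu')
network = QSMNet().to(device)
network.load_state_dict(checkpoint['model_state'])
optimizer = optim.Adam(network.parameters(), lr=learning_rate)
optimizer.load_state_dict(checkpoint['optimizer_state'])
start_epoch = checkpoint['epoch']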
Example #5
class LMTrainer:
    def __init__(self, model, dataset, checkpoint_filename):
        self.checkpoint_filename = checkpoint_filename
        self.model = model.cuda()
        self.dataset = dataset

        self.writer = SummaryWriter()
        self.global_step = 0
        self.last_log = datetime.datetime.now()
        self.last_valid = datetime.datetime.now()

        # Dataset
        # NOTE: random_split takes absolute lengths here, so this assumes len(dataset) == 100.
        train_dataset, test_dataset = random_split(dataset, [90, 10])

        self.loaders = {
            "train":
            DataLoader(
                train_dataset,
                batch_size=config.BATCH_SIZE,
                shuffle=True,
                drop_last=True,
            ),
            "test":
            DataLoader(test_dataset,
                       batch_size=config.BATCH_SIZE,
                       shuffle=True,
                       drop_last=True),
        }
        inputs = next(iter(self.loaders["test"]))
        model(inputs.cuda())
        self.writer.add_graph(model, inputs.cuda())

        # Optimizer: `self.optim` is the CyclicLR scheduler; the underlying SGD optimizer
        # is reachable as `self.optim.optimizer` (see run_batch_learn).
        self.optim = torch.optim.lr_scheduler.CyclicLR(
            torch.optim.SGD(self.model.parameters(),
                            lr=config.LEARNING_RATE,
                            momentum=0.9),
            base_lr=0,
            max_lr=config.LEARNING_RATE,
        )

        self.best_valid = 1e9
        self.patience = 0

    def run_batch_learn(self, i, batch):
        self.optim.optimizer.zero_grad()
        self.model.train()
        loss, metrics = self.compute_loss_and_metrics_from_indices(batch)
        loss.backward()
        self.optim.optimizer.step()
        self.optim.step()
        self.writer.add_scalar(
            "train/lr",
            torch.Tensor([self.optim.get_lr()]),
            global_step=self.global_step,
        )
        self.global_step += config.BATCH_SIZE

    def evaluate_batch(self, mode, batch):
        loss, metrics = self.compute_loss_and_metrics_from_indices(batch)
        for name, metric in sorted(metrics.items()):
            scalar_name = f"{mode}/{name}"
            self.writer.add_scalar(scalar_name,
                                   metric,
                                   global_step=self.global_step)
        return loss

    def evaluate(self, mode):
        total_loss = 0
        with torch.no_grad():
            self.model.eval()
            loader = self.loaders[mode]
            for batch in tqdm.tqdm(loader, total=len(loader), desc=mode):
                batch = batch.cuda()
                loss = self.evaluate_batch(mode, batch)
                total_loss += loss.item()
        return total_loss

    def run_epoch(self):
        loader = self.loaders["train"]
        for i, batch in tqdm.tqdm(enumerate(loader),
                                  total=len(loader),
                                  desc="train"):
            batch = batch.cuda()
            self.run_batch_learn(i, batch)
            now = datetime.datetime.now()
            if now - self.last_log > config.LOG_INTERVAL:
                self.last_log = now
                self.evaluate_batch("train", batch)

        if now - self.last_valid > config.TEST_INTERVAL:
            valid_loss = self.evaluate("test")
            self.last_valid = now

            if valid_loss < self.best_valid:
                torch.save(self.model, self.checkpoint_filename)
                self.best_valid = valid_loss
                self.patience = 0
            else:
                self.patience += 1
                msg = f"Did not improve valid loss, model not saved (patience: {self.patience})"
                logging.info(msg)

    def compute_loss_and_metrics_from_indices(self, indices):
        attn = self.model.forward(indices)  # attn: [B, C, D]
        decoded = self.model.decoder(attn)  # decoded: [B, C, V]
        B, C, V = decoded.shape
        x = (decoded[:, :-1, :].contiguous().view(-1, decoded.size(-1))
             )  # [B * (C - 1), V]
        logits = F.log_softmax(x, dim=-1)  # [B * (C - 1), V]
        next_words = indices[:, 1:].contiguous().view(-1)  # [B * (C - 1)]
        prediction_loss = F.nll_loss(logits, next_words,
                                     reduction="sum") / attn.size(0)
        loss_words = prediction_loss / config.CONTEXT_SIZE
        _, topk = logits.topk(10, -1)
        topks = {
            f"P_{k}":
            (topk[:, :k] == next_words.view(-1, 1)).sum(-1).float().mean() *
            100
            for k in [1, 5, 10]
        }

        losses = {
            "loss/context": prediction_loss,
            "loss/words": loss_words,
            **topks
        }
        loss = prediction_loss
        if random.random() < 0.01:
            plt.clf()
            values = torch.exp(logits.reshape(B, C - 1, V)[0, -1, :])
            plt.plot(values.detach().cpu().numpy())
            mu = 500
            sigma = 100
            x = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100)
            plt.plot(x, scipy.stats.norm.pdf(x, mu, sigma))
            mu = 500
            sigma = 10
            x = np.linspace(mu - 5 * sigma, mu + 5 * sigma, 100)
            plt.plot(x, scipy.stats.norm.pdf(x, mu, sigma))
            global INDEX
            filename = f"figures/out{INDEX}.png"
            plt.savefig(filename)
            INDEX += 1
            print(f"Saved {filename}")
        return loss, losses
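The top-k accuracy computed in compute_loss_and_metrics_from_indices above can be illustrated in isolation; the vocabulary size and number of positions below are made up.

import torch
import torch.nn.functional as F

# Toy log-probabilities over a vocabulary of 20 words for 6 flattened positions.
logits = F.log_softmax(torch.randn(6, 20), dim=-1)
next_words = torch.randint(0, 20, (6,))

_, topk = logits.topk(10, -1)             # indices of the 10 highest-scoring words, shape (6, 10)
for k in (1, 5, 10):
    # A position counts as a hit if the gold word appears among its top-k predictions.
    hits = (topk[:, :k] == next_words.view(-1, 1)).sum(-1).float()
    print(f"P_{k} = {hits.mean().item() * 100:.1f}%")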