Exemple #1
0
    def train_one_epoch(self):
        """
        Run a single training epoch over ``self.train_loader``.

        Streams batches through the model, optimises with ``self.optim``,
        tracks running loss/accuracy, logs to the summary writer, and
        triggers mid-epoch validation every ``val_every`` batches.
        """
        num_batches = self.train_len // self.config.batch_size
        tqdm_batch = tqdm(self.train_loader,
                          total=num_batches,
                          desc="[Epoch {}]".format(self.current_epoch),
                          disable=self.disable_progressbar)

        # BUG FIX: guard against ZeroDivisionError when mid-epoch
        # validation is disabled (validations_per_epoch == 0), matching
        # the guard used by the other trainer in this file.
        validations = self.config['validations_per_epoch']
        val_every = max(num_batches // validations, 1) if validations else None
        self.model.train()

        epoch_loss = AverageMeter()
        epoch_acc = AverageMeter()

        for batch_i, data_list in enumerate(tqdm_batch):
            # The trailing four collated entries are fixed-position
            # tensors; everything before them is model input.
            program_args, rvAssignments, _, rvOrders, rvOrders_lengths = \
                data_list[:-4], data_list[-4], data_list[-3], data_list[-2], data_list[-1]

            for i in range(len(program_args)):
                program_args[i] = program_args[i].to(self.device)

            rvAssignments = rvAssignments.to(self.device)
            rvOrders = rvOrders.to(self.device)
            rvOrders_lengths = rvOrders_lengths.to(self.device)

            labels = self._createLabels(rvOrders, rvAssignments,
                                        rvOrders_lengths)

            # reset optimiser gradients
            self.optim.zero_grad()

            outputs = self.model(program_args)
            outputs = self._chooseOutputs(outputs, rvOrders, rvOrders_lengths)

            loss, avg_acc, num_total = self._compute_loss(
                outputs, labels, rvOrders_lengths - 1)
            loss.backward()

            self.optim.step()

            epoch_loss.update(loss.item(), n=num_total)
            epoch_acc.update(avg_acc, n=num_total)
            tqdm_batch.set_postfix({
                "Loss": epoch_loss.avg,
                "Avg acc": epoch_acc.avg
            })

            self.summary_writer.add_scalars("epoch/loss",
                                            {'loss': epoch_loss.val},
                                            self.current_iteration)
            self.summary_writer.add_scalars("epoch/accuracy",
                                            {'accuracy': epoch_acc.val},
                                            self.current_iteration)

            self.current_iteration += 1

            if val_every and (batch_i + 1) % val_every == 0:
                self.validate()
                self.model.train()  # put back in training mode

        tqdm_batch.close()
Exemple #2
0
    def train_one_epoch(self):
        """
        One epoch of training.

        Iterates the training loader, optimises the model on per-RV
        labels, and logs running loss/accuracy to the progress bar and
        summary writer.
        """
        num_batches = self.train_len // self.config.batch_size
        tqdm_batch = tqdm(self.train_loader,
                          total=num_batches,
                          desc="[Epoch {}]".format(self.current_epoch))

        self.model.train()

        epoch_loss = AverageMeter()
        epoch_acc = AverageMeter()

        for data_list in tqdm_batch:
            # The trailing four collated entries are fixed-position
            # tensors; only rvAssignments is needed here.
            program_args, rvAssignments = data_list[:-4], data_list[-4]
            batch_size = len(rvAssignments)

            for i in range(len(program_args)):
                program_args[i] = program_args[i].to(self.device)

            rvAssignments = rvAssignments.to(self.device)
            # one label tensor per random variable (one per column)
            labels_list = [
                rvAssignments[:, i] for i in range(rvAssignments.size(1))
            ]

            # reset optimiser gradients
            self.optim.zero_grad()

            # get outputs and predictions
            outputs_list = self.model(program_args)
            loss = self._compute_loss(outputs_list, labels_list)
            accuracies = self._compute_accuracy(outputs_list, labels_list)
            avg_acc = np.mean(accuracies)

            loss.backward()
            self.optim.step()

            epoch_loss.update(loss.item(), n=batch_size)
            epoch_acc.update(avg_acc, n=batch_size)
            tqdm_batch.set_postfix({
                "Loss": epoch_loss.avg,
                "Avg acc": epoch_acc.avg
            })

            self.summary_writer.add_scalars("epoch/loss",
                                            {'loss': epoch_loss.val},
                                            self.current_iteration)
            self.summary_writer.add_scalars("epoch/accuracy",
                                            {'accuracy': epoch_acc.val},
                                            self.current_iteration)

            self.current_iteration += 1

        tqdm_batch.close()
Exemple #3
0
    def _test(self, name, loader, length):
        """
        Returns validation accuracy. Unlike training, we compute accuracy
        (1) per decision node and (2) sorted by HEAD, BODY, and TAIL.
        """
        total_batches = length // self.config.batch_size
        progress = tqdm(loader,
                        total=total_batches,
                        desc="[{}]".format(name.capitalize()),
                        disable=self.disable_progressbar)

        # evaluation mode: freezes dropout / batch-norm statistics
        self.model.eval()
        loss_meter = AverageMeter()
        avg_acc_meter = AverageMeter()

        with torch.no_grad():
            for data_list in progress:
                X, (_counts, y) = data_list
                n_samples = len(y)

                X = X.to(device=self.device, dtype=torch.float32)
                y = y.to(device=self.device, dtype=torch.long)

                scores = self.model(X)
                loss = F.cross_entropy(scores, y)
                ll = torch.softmax(scores, 1)
                preds = torch.argmax(ll, 1)
                accuracy = torch.sum(
                    preds == y).float().cpu().numpy() / y.size(0)

                # track running statistics for this evaluation pass
                loss_meter.update(loss.item(), n=n_samples)
                avg_acc_meter.update(accuracy, n=n_samples)

                progress.set_postfix({
                    "{} Loss".format(name.capitalize()): loss_meter.avg,
                    "Avg acc": avg_acc_meter.avg,
                })

                self.summary_writer.add_scalars("{}/loss".format(name),
                                                {'loss': loss_meter.val},
                                                self.current_val_iteration)
                self.summary_writer.add_scalars(
                    "{}/accuracy".format(name),
                    {'accuracy': avg_acc_meter.val},
                    self.current_val_iteration)

                self.current_val_iteration += 1

        print('AVERAGE ACCURACY: {}'.format(avg_acc_meter.avg))
        progress.close()
Exemple #4
0
    def train_one_epoch(self):
        """
        Run one full pass over the training set, updating the model and
        logging running loss/accuracy.
        """
        batches_per_epoch = self.train_len // self.config.batch_size
        progress = tqdm(self.train_loader,
                        total=batches_per_epoch,
                        desc="[Epoch {}]".format(self.current_epoch))

        self.model.train()

        loss_meter = AverageMeter()
        acc_meter = AverageMeter()

        for seq_src, _, seq_len, label, _ in progress:
            n_samples = len(seq_src)

            seq_src = seq_src.to(self.device)
            seq_len = seq_len.to(self.device)
            label = label.to(self.device)

            # clear gradients before the backward pass
            self.optim.zero_grad()

            # forward pass and metrics
            output = self.model(seq_src, seq_len)
            loss = self._compute_loss(output, label)
            acc = self._compute_accuracy(output, label)

            loss.backward()
            self.optim.step()

            loss_meter.update(loss.item(), n=n_samples)
            acc_meter.update(acc, n=n_samples)
            progress.set_postfix({"Loss": loss_meter.avg,
                                  "Avg acc": acc_meter.avg})

            self.summary_writer.add_scalars("epoch/loss",
                                            {'loss': loss_meter.val},
                                            self.current_iteration)
            self.summary_writer.add_scalars("epoch/accuracy",
                                            {'accuracy': acc_meter.val},
                                            self.current_iteration)

            self.current_iteration += 1

        progress.close()
Exemple #5
0
    def _test(self, name, loader, length):
        """
        Evaluate the model on ``loader``, tracking mean loss and accuracy
        and logging both to the summary writer.
        """
        total_batches = length // self.config.batch_size
        progress = tqdm(loader,
                        total=total_batches,
                        desc="[{}]".format(name.capitalize()))

        # evaluation mode: freezes dropout / batch-norm statistics
        self.model.eval()
        loss_meter = AverageMeter()
        accuracy_meter = AverageMeter()

        for data_list in progress:
            # last four collated entries are fixed-position tensors
            program_args, rvAssignments = data_list[:-4], data_list[-4]
            n_samples = len(rvAssignments)

            program_args = [arg.to(self.device) for arg in program_args]

            rvAssignments = rvAssignments.to(self.device)
            # one label tensor per random variable (one per column)
            labels_list = [rvAssignments[:, col]
                           for col in range(rvAssignments.size(1))]

            outputs_list = self.model(program_args)
            loss = self._compute_loss(outputs_list, labels_list)
            avg_acc = np.mean(self._compute_accuracy(outputs_list, labels_list))

            loss_meter.update(loss.item(), n=n_samples)
            accuracy_meter.update(avg_acc, n=n_samples)

            progress.set_postfix({
                "{} Loss".format(name.capitalize()): loss_meter.avg,
                "Avg acc": accuracy_meter.avg,
            })

            self.summary_writer.add_scalars("{}/loss".format(name),
                                            {'loss': loss_meter.val},
                                            self.current_val_iteration)
            self.summary_writer.add_scalars("{}/accuracy".format(name),
                                            {'accuracy': accuracy_meter.val},
                                            self.current_val_iteration)

            self.current_val_iteration += 1

        progress.close()
Exemple #6
0
    def _test(self, name, loader, length):
        """
        Returns validation accuracy. Unlike training, we compute accuracy
        (1) per decision node and (2) sorted by HEAD, BODY, and TAIL.

        Side effects: appends the per-tier accuracy dict to
        ``self.accuracies``, feeds the mean loss to ``self.early_stopping``,
        logs to the summary writer, and prints accuracy tables.
        """
        num_batches = length // self.config.batch_size
        tqdm_batch = tqdm(loader, total=num_batches,
                          desc="[{}]".format(name.capitalize()), disable=self.disable_progressbar)

        # set the model in validation mode
        self.model.eval()
        loss_meter = AverageMeter()
        avg_acc_meter = AverageMeter()

        # Per-tier correct counts / totals: rows are HEAD/BODY/TAIL,
        # columns index RVs (size num_rv_nodes + 1 — presumably an extra
        # padding/offset slot; confirm against _compute_tier_stats).
        tier_counts = np.zeros((3, self.num_rv_nodes + 1))
        tier_norm = np.zeros((3, self.num_rv_nodes + 1))

        for data_list in tqdm_batch:
            # the trailing four collated entries are fixed-position tensors
            program_args, rvAssignments, tiers, rvOrders, rvOrders_lengths = \
                data_list[:-4], data_list[-4], data_list[-3], data_list[-2], data_list[-1]

            for i in range(len(program_args)):
                program_args[i] = program_args[i].to(self.device)

            # N x all_num_rvs tensor containing (indexes of) final values of all rvs for each data point
            rvAssignments = rvAssignments.to(self.device)

            # rvOrders is a list of shape N x T where each row contains the (padded) render order for each data point
            # Each entry in column t is the index of the rv that was rendered at time t for this datapoint
            rvOrders = rvOrders.to(self.device)
            rvOrders_lengths = rvOrders_lengths.to(self.device)
            rvOrdersShifted = rvOrders[:, 1:]        # Shifted left by one to match the T-1 predictions made

            # shape N x (T - 1)
            labels = self._createLabels(rvOrders, rvAssignments, rvOrders_lengths)

            # outputs are list of list of tensors of shape num_batches x (T-1) x c_t
            output, alphas = self.model(rvOrders.long(), rvOrders_lengths.long(),
                                        rvAssignments.long(), program_args)

            loss, avg_acc, num_total = self._compute_loss(output, labels, rvOrders_lengths)
            tier_counts_, tier_norm_ = self._compute_tier_stats(output, labels, tiers, rvOrders_lengths, rvOrdersShifted)

            # accumulate per-tier statistics across batches
            tier_counts += tier_counts_
            tier_norm += tier_norm_

            if self.use_attention:
                # Frobenius-norm regulariser on the attention maps, added
                # so the reported loss matches the training objective
                assert len(alphas) > 0
                frob_loss = self._compute_frobenius_norm(alphas)
                loss = loss + frob_loss

            # write data and summaries
            loss_meter.update(loss.item(), n=num_total)
            avg_acc_meter.update(avg_acc, n=num_total)

            tqdm_batch.set_postfix({"{} Loss".format(name.capitalize()): loss_meter.avg,
                                    "Avg acc": avg_acc_meter.avg})

            self.summary_writer.add_scalars("{}/loss".format(name), {'loss': loss_meter.val}, self.current_val_iteration)
            self.summary_writer.add_scalars("{}/accuracy".format(name), {'accuracy': avg_acc_meter.val}, self.current_val_iteration)

            self.current_val_iteration +=  1

        # Per-RV accuracy for each tier; entries with zero samples produce
        # NaN via 0/0 division.
        tier_accuracy = tier_counts / tier_norm
        head_acc = {'rv: {}'.format(self.train_dataset.rv_info['i2w'][str(idx)]): 
                    tier_accuracy[0, idx] for idx in range(self.num_rv_nodes + 1)}
        body_acc = {'rv: {}'.format(self.train_dataset.rv_info['i2w'][str(idx)]): 
                    tier_accuracy[1, idx] for idx in range(self.num_rv_nodes + 1)}
        tail_acc = {'rv: {}'.format(self.train_dataset.rv_info['i2w'][str(idx)]): 
                    tier_accuracy[2, idx] for idx in range(self.num_rv_nodes + 1)}

        acc_dict = {'head_acc': head_acc, 'body_acc': body_acc, 'tail_acc': tail_acc}
        self.accuracies.append(acc_dict)

        print('AVERAGE ACCURACY: {}'.format(avg_acc_meter.avg))


        self.early_stopping.update(loss_meter.avg)

        print('[HEAD] {} accuracy per RV ({} total): '.format(name.capitalize(), tier_norm[0, 1]))
        pprint(head_acc)

        print('[BODY] {} accuracy per RV ({} total): '.format(name.capitalize(), tier_norm[1, 1]))
        pprint(body_acc)

        print('[TAIL] {} accuracy per RV ({} total): '.format(name.capitalize(), tier_norm[2, 1]))
        pprint(tail_acc)

        tqdm_batch.close()

        return acc_dict
Exemple #7
0
    def train_one_epoch(self):
        """
        One epoch of training.

        Runs the autoregressive model over the training loader, applies
        the optional attention Frobenius penalty, logs running metrics,
        and performs mid-epoch validation when configured.
        """
        num_batches = self.train_len // self.config.batch_size
        tqdm_batch = tqdm(self.train_loader, total=num_batches,
                          desc="[Epoch {}]".format(self.current_epoch),
                          disable=self.disable_progressbar)

        # validations_per_epoch == 0 disables mid-epoch validation
        validations = self.config['validations_per_epoch']
        val_every = None if validations == 0 else max(num_batches // validations, 1)
        self.model.train()

        epoch_loss = AverageMeter()
        epoch_acc = AverageMeter()

        for batch_i, data_list in enumerate(tqdm_batch):
            # the trailing four collated entries are fixed-position tensors
            program_args = data_list[:-4]
            rvAssignments, _, rvOrders, rvOrders_lengths = data_list[-4:]

            program_args = [arg.to(self.device) for arg in program_args]

            # N x all_num_rvs tensor of (indexes of) final rv values
            rvAssignments = rvAssignments.to(self.device)

            # rvOrders: N x T padded render order; entry at column t is the
            # index of the rv rendered at time t for that datapoint
            rvOrders = rvOrders.to(self.device)
            rvOrders_lengths = rvOrders_lengths.to(self.device)

            # labels: shape N x (T - 1)
            labels = self._createLabels(rvOrders, rvAssignments, rvOrders_lengths)

            # reset optimiser gradients
            self.optim.zero_grad()

            # outputs: list of list of tensors of shape num_batches x (T-1) x c_t
            output, alphas = self.model(rvOrders.long(), rvOrders_lengths.long(),
                                        rvAssignments.long(), program_args)

            loss, avg_acc, num_total = self._compute_loss(output, labels, rvOrders_lengths)

            if self.use_attention:
                assert len(alphas) > 0
                loss = loss + self._compute_frobenius_norm(alphas)

            loss.backward()
            self.optim.step()

            epoch_loss.update(loss.item(), n=num_total)
            epoch_acc.update(avg_acc, n=num_total)
            tqdm_batch.set_postfix({"Loss": epoch_loss.avg, "Avg acc": epoch_acc.avg})

            self.summary_writer.add_scalars("epoch/loss", {'loss': epoch_loss.val}, self.current_iteration)
            self.summary_writer.add_scalars("epoch/accuracy", {'accuracy': epoch_acc.val}, self.current_iteration)

            self.current_iteration += 1

            if val_every and (batch_i + 1) % val_every == 0:
                self.validate()
                self.model.train()  # put back in training mode

        tqdm_batch.close()
Exemple #8
0
 def eval_fn(self, data_loader, return_predictions=False):
     """
     Run a no-grad evaluation pass over ``data_loader``.

     Returns a dict with the mean loss and, when ``self.metrics`` is set,
     one entry per validation metric (plus per-metric predictions/labels
     when ``return_predictions`` is True).
     """
     from contextlib import nullcontext

     losses = AverageMeter("loss")
     if self.metrics is not None:
         meters = Metrics(*self.metrics["validation"],
                          mode="validation",
                          return_predictions=return_predictions)
     else:
         meters = None
     self.model.to(self.params.device)
     self.model.eval()
     with torch.no_grad():
         iterator = tqdm(data_loader, total=len(data_loader))
         for b_idx, data in enumerate(iterator):
             self.to_device(data, self.params.device)
             # The fp16 and fp32 paths previously duplicated the entire
             # forward-pass logic; they only differ by the autocast
             # context, so wrap a single copy in the right context.
             forward_ctx = amp.autocast() if self.fp16 else nullcontext()
             with forward_ctx:
                 if isinstance(self.model, BaseEncoderModel):
                     if self.model.input_dict:
                         model_output = self.model(**data.to_dict())
                     else:
                         model_output = self.model(data)
                     if isinstance(model_output, ClassifierOutput):
                         loss = model_output.loss
                         logits = model_output.predictions
                     else:
                         # raw logits; no loss available from this model
                         logits = model_output
                         loss = None
                 else:
                     if not isinstance(data, dict):
                         features = data.to_dict()
                         labels = data.labels
                     else:
                         features = data
                         labels = data["labels"]
                     model_output = self.model(**features, labels=labels)
                     loss = model_output[0]
                     logits = model_output[1]
             if loss is not None:
                 losses.update(loss.item(), self.params.batch_size)
             if meters is not None:
                 if isinstance(data, dict):
                     labels = data["labels"].cpu().numpy()
                 else:
                     labels = data.labels.cpu().numpy()
                 if logits is not None:
                     logits = logits.detach().cpu().numpy()
                     for m in meters.metrics:
                         m.update(logits, labels, n=self.params.batch_size)
                 postfix_dict = meters.set_postfix()
                 if loss is not None:
                     postfix_dict["loss"] = losses.avg
                 iterator.set_postfix(**postfix_dict)
         iterator.close()
     if self.verbose and meters is not None:
         meters.display_metrics()
     results = {"loss": losses.avg}
     if meters is not None:
         for m in meters.metrics:
             results[m.get_name] = m.avg
             if return_predictions:
                 results[f"predictions_{m.get_name}"] = m.all_predictions
                 results[f"labels_{m.get_name}"] = m.all_labels
             m.reset()
     return results
Exemple #9
0
    def train_fn(self, data_loader):
        """
        Run one training pass over ``data_loader``.

        Returns a dict with the mean loss plus one entry per configured
        training metric.

        Raises:
            ImportError: when ``self.use_tpu`` is set but torch_xla is
                not installed.
        """
        if self.use_tpu:
            try:
                import torch_xla
                import torch_xla.core.xla_model as xm
            except ImportError:
                # BUG FIX: the original built ImportError(...) without
                # raising it, so a missing XLA install surfaced later as
                # a confusing NameError on `xm`.
                raise ImportError("Pytorch XLA is not available")
            self.device = xm.xla_device()
        losses = AverageMeter("loss")
        if self.metrics is not None:
            meters = Metrics(*self.metrics["training"])
        else:
            meters = None
        iterator = tqdm(data_loader, total=len(data_loader))
        self.model.to(self.params.device)
        self.model.train()
        for b_idx, data in enumerate(iterator):
            self.to_device(data, self.params.device)
            skip_scheduler = False
            if self.use_tpu:
                loss, logits = self._tpu_step(data, b_idx)
            elif self.fp16:
                loss, logits, scale_before_step = self.mixed_precision_step(
                    data, b_idx)
                # when the grad scaler lowered its scale, the optimizer
                # step was skipped, so the scheduler must be skipped too
                skip_scheduler = self.scaler.get_scale() != scale_before_step
            else:
                loss, logits = self.step(data, b_idx)

            if (b_idx + 1) % self.accumulation_steps == 0:
                self.optimizer.zero_grad()

            if self.scheduler is not None:
                if not skip_scheduler:
                    self.scheduler.step()
            if self.replacing_rate_scheduler is not None:
                self.replacing_rate_scheduler.step()

            losses.update(loss.item(), self.params.batch_size)
            if meters is not None:
                if isinstance(data, dict):
                    labels = data["labels"].cpu().numpy()
                else:
                    labels = data.labels.cpu().numpy()
                if logits is not None:
                    logits = logits.detach().cpu().numpy()
                    for m in meters.metrics:
                        m.update(logits, labels, n=self.params.batch_size)
            # Progress display (TPU runs skip tqdm postfix updates). The
            # original updated the postfix twice per batch when metrics
            # were enabled; once is enough.
            if not self.use_tpu:
                if meters is not None:
                    iterator.set_postfix(loss=losses.avg,
                                         **meters.set_postfix())
                else:
                    iterator.set_postfix({"loss": "{:.2f}".format(losses.avg)})
        if not self.use_tpu:
            iterator.close()
        if self.verbose and meters is not None and not self.use_tpu:
            meters.display_metrics()
        results = {"loss": losses.avg}
        if meters is not None:
            for m in meters.metrics:
                results[m.get_name] = m.avg
                m.reset()
        return results
Exemple #10
0
    def _test(self, name, loader, length):
        """
        Evaluate on ``loader`` and log loss plus overall, head-tier, and
        tail-tier accuracy to the progress bar and summary writer.
        """
        num_batches = length // self.config.batch_size
        tqdm_batch = tqdm(loader,
                          total=num_batches,
                          desc="[{}]".format(name.capitalize()))

        # set the model in validation mode
        self.model.eval()
        loss_meter = AverageMeter()
        accuracy_meter = AverageMeter()
        head_accuracy_meter = AverageMeter()
        tail_accuracy_meter = AverageMeter()

        # no gradients needed for evaluation
        with torch.no_grad():
            for seq_src, _, seq_len, label, tier in tqdm_batch:
                batch_size = len(seq_src)

                seq_src = seq_src.to(self.device)
                seq_len = seq_len.to(self.device)
                label = label.to(self.device)

                output = self.model(seq_src, seq_len)
                loss = self._compute_loss(output, label)
                unrolled_acc = self._compute_accuracy(output, label, reduce=False)
                # BUG FIX: np.int was removed in NumPy 1.24; the builtin
                # int is the documented replacement.
                unrolled_acc = unrolled_acc[:, 0].astype(int)
                acc = np.mean(unrolled_acc)

                tier_np = tier.numpy()
                head_mask = tier_np == 1
                tail_mask = tier_np == 0

                loss_meter.update(loss.item(), n=batch_size)
                accuracy_meter.update(acc, n=batch_size)
                # BUG FIX: np.mean of an empty tier slice is NaN and
                # would poison the running averages; skip empty tiers.
                if head_mask.any():
                    head_accuracy_meter.update(np.mean(unrolled_acc[head_mask]),
                                               int(head_mask.sum()))
                if tail_mask.any():
                    tail_accuracy_meter.update(np.mean(unrolled_acc[tail_mask]),
                                               int(tail_mask.sum()))

                tqdm_batch.set_postfix({
                    "{} Loss".format(name.capitalize()):
                    loss_meter.avg,
                    "Avg acc":
                    accuracy_meter.avg
                })

                self.summary_writer.add_scalars("{}/loss".format(name),
                                                {'loss': loss_meter.val},
                                                self.current_val_iteration)
                self.summary_writer.add_scalars("{}/accuracy".format(name),
                                                {'accuracy': accuracy_meter.val},
                                                self.current_val_iteration)
                self.summary_writer.add_scalars(
                    "{}/headAccuracy".format(name),
                    {'headAccuracy': head_accuracy_meter.val},
                    self.current_val_iteration)
                self.summary_writer.add_scalars(
                    "{}/tailAccuracy".format(name),
                    {'tailAccuracy': tail_accuracy_meter.val},
                    self.current_val_iteration)

                self.current_val_iteration += 1

        tqdm_batch.close()
Exemple #11
0
    def _test(self, name, loader, length):
        """
        Evaluate the model on ``loader`` and log mean loss/accuracy.

        Feeds the mean loss to ``self.early_stopping``. Dead per-tier
        bookkeeping (arrays allocated but all use commented out) has been
        removed along with the unused ``rvOrdersShifted``.
        """
        num_batches = length // self.config.batch_size
        tqdm_batch = tqdm(loader,
                          total=num_batches,
                          desc="[{}]".format(name.capitalize()),
                          disable=self.disable_progressbar)

        # set the model in validation mode
        self.model.eval()
        loss_meter = AverageMeter()
        avg_acc_meter = AverageMeter()

        # no gradients needed for evaluation
        with torch.no_grad():
            for data_list in tqdm_batch:
                # the trailing four collated entries are fixed-position
                # tensors; `tiers` is unused now that tier stats are gone
                program_args, rvAssignments, tiers, rvOrders, rvOrders_lengths = \
                    data_list[:-4], data_list[-4], data_list[-3], data_list[-2], data_list[-1]

                for i in range(len(program_args)):
                    program_args[i] = program_args[i].to(self.device)

                rvAssignments = rvAssignments.to(self.device)
                rvOrders = rvOrders.to(self.device)
                rvOrders_lengths = rvOrders_lengths.to(self.device)

                labels = self._createLabels(rvOrders, rvAssignments,
                                            rvOrders_lengths)
                outputs = self.model(program_args)
                outputs = self._chooseOutputs(outputs, rvOrders, rvOrders_lengths)

                loss, avg_acc, num_total = self._compute_loss(
                    outputs, labels, rvOrders_lengths - 1)

                # write data and summaries
                loss_meter.update(loss.item(), n=num_total)
                avg_acc_meter.update(avg_acc, n=num_total)

                tqdm_batch.set_postfix({
                    "{} Loss".format(name.capitalize()):
                    loss_meter.avg,
                    "Avg acc":
                    avg_acc_meter.avg
                })

                self.summary_writer.add_scalars("{}/loss".format(name),
                                                {'loss': loss_meter.val},
                                                self.current_val_iteration)
                self.summary_writer.add_scalars("{}/accuracy".format(name),
                                                {'accuracy': avg_acc_meter.val},
                                                self.current_val_iteration)

                self.current_val_iteration += 1

        print('AVERAGE ACCURACY: {}'.format(avg_acc_meter.avg))

        self.early_stopping.update(loss_meter.avg)

        tqdm_batch.close()
Exemple #12
0
    def evaluateAgent(self):
        """
        Evaluate ``self.model`` on the test set and collect per-class
        statistics.

        Returns a dict with the average accuracy, a confusion matrix, and
        per-sample (correct[, correct_group], count) records; adds a
        group-level confusion matrix when ``self.n_labels == 13``.
        """
        num_batches = len(self.test_dataset) // self.config.batch_size
        tqdm_batch = tqdm(self.dataloader, total=num_batches)

        loss_meter = AverageMeter()
        avg_acc_meter = AverageMeter()

        cm = np.zeros((self.n_labels, self.n_labels))
        grouped = self.n_labels == 13
        if grouped:
            group_acc_meter = AverageMeter()
            cm_group = np.zeros((5, 5))

        accuracy_data = []

        # no gradients needed for evaluation
        with torch.no_grad():
            for _, data_list in enumerate(tqdm_batch):
                X, (counts, y) = data_list

                X = X.to(device=self.agent.device, dtype=torch.float32)
                y = y.to(device=self.agent.device, dtype=torch.long)

                if grouped:
                    valid_mask = y < 13     # not predicting labels after samp13
                    X = X[valid_mask]
                    y = y[valid_mask]
                    counts = counts[valid_mask]

                batch_size = len(y)

                scores = self.model(X)
                loss = F.cross_entropy(scores, y)
                ll = torch.softmax(scores, 1)
                preds = torch.argmax(ll, 1)
                accuracy = torch.sum(preds == y).float().cpu().numpy() / y.size(0)

                y_np = y.cpu().numpy()
                preds_np = preds.cpu().numpy()
                counts_np = counts.cpu().numpy()
                correct = preds_np == y_np

                if grouped:
                    y_group_np = self.group_lookup[y_np]
                    preds_group_np = self.group_lookup[preds_np]
                    correct_group = preds_group_np == y_group_np
                    group_accuracy = np.sum(correct_group).astype(float) / y_group_np.shape[0]
                    accuracy_data.extend(zip(correct, correct_group, counts_np))
                else:
                    accuracy_data.extend(zip(correct, counts_np))

                # BUG FIX: `cm[y_np, preds_np] += 1` uses buffered fancy
                # indexing, so duplicate (true, pred) pairs within a batch
                # were counted only once; np.add.at accumulates all of them.
                np.add.at(cm, (y_np, preds_np), 1)
                if grouped:
                    np.add.at(cm_group, (y_group_np, preds_group_np), 1)

                # write data and summaries
                loss_meter.update(loss.item(), n=batch_size)
                avg_acc_meter.update(accuracy, n=batch_size)

                if grouped:
                    group_acc_meter.update(group_accuracy, n=batch_size)
                    tqdm_batch.set_postfix({"Group acc": group_acc_meter.avg,
                                            "Avg acc": avg_acc_meter.avg})
                else:
                    tqdm_batch.set_postfix({"Avg acc": avg_acc_meter.avg})

        print('AVERAGE ACCURACY: {}'.format(avg_acc_meter.avg))

        results = {
            'avg_acc': avg_acc_meter.avg,
            'confusion_matrix': cm,
            'accuracy_data': accuracy_data
        }
        if grouped:
            results['confusion_matrix_group'] = cm_group

        return results
Exemple #13
0
    def train_one_epoch(self):
        """Run one full training pass over ``self.train_loader``.

        Maintains running loss/accuracy meters, logs both to TensorBoard on
        every iteration, and calls ``self.validate()`` at a fixed batch
        interval derived from ``config['validations_per_epoch']`` (0 disables
        mid-epoch validation).
        """
        total_batches = self.train_len // self.config.batch_size
        progress = tqdm(self.train_loader,
                        total=total_batches,
                        desc="[Epoch {}]".format(self.current_epoch),
                        disable=self.disable_progressbar)

        per_epoch = self.config['validations_per_epoch']
        val_every = max(total_batches // per_epoch, 1) if per_epoch != 0 else None

        self.model.train()

        loss_meter = AverageMeter()
        acc_meter = AverageMeter()

        for step, batch in enumerate(progress):
            inputs, (counts, targets) = batch
            n_examples = len(targets)

            inputs = inputs.to(device=self.device, dtype=torch.float32)
            targets = targets.to(device=self.device, dtype=torch.long)

            logits = self.model(inputs)
            loss = F.cross_entropy(logits, targets)
            probs = torch.softmax(logits, 1)
            predictions = torch.argmax(probs, 1)
            batch_acc = (torch.sum(predictions == targets).float().cpu().numpy()
                         / targets.size(0))

            # standard optimiser step: clear grads, backprop, apply update
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            loss_meter.update(loss.item(), n=n_examples)
            acc_meter.update(batch_acc, n=n_examples)
            progress.set_postfix({
                "Loss": loss_meter.avg,
                "Avg acc": acc_meter.avg
            })

            self.summary_writer.add_scalars("epoch/loss",
                                            {'loss': loss_meter.val},
                                            self.current_iteration)
            self.summary_writer.add_scalars("epoch/accuracy",
                                            {'accuracy': acc_meter.val},
                                            self.current_iteration)

            self.current_iteration += 1

            if val_every and (step + 1) % val_every == 0:
                self.validate()
                self.model.train()  # validate() leaves eval mode; restore

        progress.close()