Example 1
    def train_epoch(self):
        """
        Train one epoch. Invoked by the train function; usually not necessary to call directly.

        :return: train metric result
        """
        network = self.state['network']
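        # Module.train() returns the module itself; the 'or' keeps a valid
        # reference even if an implementation's train() returned None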
        network = network.train() or network
        optimizer = self.state['optimizer']
        loss = self.state['loss']
        loss_metric = AverageValue()
        metric = self.state['metric']
        metric.reset()
        policy = self.state['train_policy']
        n = self.state['train_iter'].epoch_iterations
        m = self.state['train_iter'].num_mini_batches
        printer = make_printer(desc='TRAIN',
                               total=n,
                               disable=self.state['quiet'])
        train_metric = dict()
        self._run_hooks('train_begin')
        for iteration, mini_batch in self.state['train_iter']:
            optimizer.zero_grad()
            for sample in mini_batch:
                self._run_hooks('train_pre_forward', sample=sample)
                output = network.forward(sample)
                l = loss(output)
                train_metric.update(metric(output), loss=loss_metric(l).item())
                if m > 1:
                    l /= m
                self._run_hooks('train_forward',
                                metric=train_metric,
                                loss=l,
                                output=output)
                l.backward()
                self._run_hooks('train_backward',
                                metric=train_metric,
                                loss=l,
                                output=output)

                # save an input and an output image once per epoch
                if iteration % n == 0:
                    save_image(sample['image'][0], 'input', self.state['rank'],
                               str(int(iteration / self.state['epoch'])))
                    save_image(output['output'][0], 'output',
                               self.state['rank'],
                               str(int(iteration / self.state['epoch'])))
            policy.step(iteration / n, train_metric['loss'])
            optimizer.step()
            printer(**train_metric)
        self.state['train_metric_values'].append(train_metric)
        self._run_hooks('train_end')
        return train_metric
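The loop above accumulates gradients over the m samples of a mini-batch before a single optimizer step; dividing the loss by m turns the accumulated gradient into an average instead of a sum. A minimal standalone sketch of the same pattern in PyTorch (model, optimizer, criterion and samples are hypothetical names, not part of the trainer above):

    def accumulate_step(model, optimizer, criterion, samples):
        """One optimizer step over a group of samples (gradient accumulation)."""
        m = len(samples)
        optimizer.zero_grad()
        for inputs, targets in samples:
            loss = criterion(model(inputs), targets)
            # scale so the summed gradients equal the gradient of the mean loss
            (loss / m).backward()
        optimizer.step()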
Example 2
    def validate_epoch(self, epoch):
        """
        Validate one epoch. Invoked by the train function;
        usually not necessary to call directly.

        :return: val metric result
        """
        network = self.state['network']
        network = network.eval() or network
        loss = self.state['val_loss'] or self.state['loss']
        loss_metric = AverageValue()
        classfier_loss = self.state['classfier_loss']
        metric = self.state['val_metric'] or self.state['metric']
        metric.reset()
        policy = self.state['val_policy']
        n = self.state['val_iter'].epoch_iterations
        printer = make_printer(desc='VAL',
                               total=n,
                               disable=self.state['quiet'])
        val_metric = dict()
        self._run_hooks('val_begin')
        for iteration, mini_batch in self.state['val_iter']:
            for sample in mini_batch:
                self._run_hooks('val_pre_forward', sample=sample)
                with torch.no_grad():
                    output = network.forward(sample)
                    l = loss(output)
                    l2 = classfier_loss(output)
                    l += l2
                val_metric.update(
                    metric(output),
                    loss=loss_metric(l).item(),
                )
                self._run_hooks('val_forward',
                                metric=val_metric,
                                loss=l,
                                output=output)
            printer(**val_metric)
        policy.step(epoch, val_metric['loss'])
        self.state['val_metric_values'].append(val_metric)
        self._run_hooks('val_end')
        return val_metric
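make_printer is used here as a factory for a per-iteration progress callback that accepts the current metric values as keyword arguments. A plausible sketch built on tqdm (an assumption; the project's actual helper may differ):

    from tqdm import tqdm

    def make_printer(desc, total, disable=False):
        """Return a callable that advances a progress bar and displays metrics."""
        bar = tqdm(desc=desc, total=total, disable=disable)

        def printer(**metrics):
            bar.set_postfix(**metrics)  # show the latest metric values
            bar.update(1)               # advance by one iteration

        return printer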
Example 3
    def train_epoch(self):
        """
        Train one epoch. Invoked by the train function; usually not necessary to call directly.

        :return: train metric result
        """
        network = self.state['network']
        network = network.train() or network
        optimizer = self.state['optimizer']
        loss = self.state['loss']
        classfier_loss = self.state['classfier_loss']
        loss_metric = AverageValue()
        metric = self.state['metric']
        metric.reset()
        policy = self.state['train_policy']
        n = self.state['train_iter'].epoch_iterations
        m = self.state['train_iter'].num_mini_batches

        printer = make_printer(desc='TRAIN',
                               total=n,
                               disable=self.state['quiet'])
        train_metric = dict()
        self._run_hooks('train_begin')
        for iteration, mini_batch in self.state['train_iter']:
            optimizer.zero_grad()
            for sample in mini_batch:
                self._run_hooks('train_pre_forward', sample=sample)
                output = network.forward(sample)
                l = loss(output)

                # at the moment classifier loss is much bigger than autoencoder loss
                l2 = classfier_loss(output)
                # weight the classifier term so it does not overwhelm the overall loss
                alpha = (l / l2) + 0.05
                total_loss = l + (alpha * l2)
                train_metric.update(metric(output),
                                    AE_loss=l.item(),
                                    classifier_loss=l2.item(),
                                    total_loss=loss_metric(total_loss).item())

                if m > 1:
                    total_loss /= m
                self._run_hooks('train_forward',
                                metric=train_metric,
                                loss=total_loss,
                                output=output)
                total_loss.backward()
                self._run_hooks('train_backward',
                                metric=train_metric,
                                loss=total_loss,
                                output=output)

            # the policy expects an iteration-based progress value, which is not
            # directly controllable here, so step it with the fractional epoch
            policy.step(iteration / n, train_metric['total_loss'])
            optimizer.step()
            printer(**train_metric)

        self.state['train_metric_values'].append(train_metric)
        self._run_hooks('train_end')
        return train_metric
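A note on the weighting in Example 3: substituting alpha = (l / l2) + 0.05 into total_loss shows that it simplifies algebraically, so the gradient is that of the autoencoder loss with weight 2 and the classifier loss with a fixed weight of 0.05:

    total_loss = l + ((l / l2) + 0.05) * l2
               = l + l + 0.05 * l2
               = 2 * l + 0.05 * l2

If the loss ratio is meant to act as a constant scale rather than participate in the gradient, alpha would need to be computed from a detached ratio, e.g. alpha = (l / l2).detach() + 0.05.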
Example 4
    def train_epoch(self):
        """
        Train one epoch. Invoked by the train function; usually not necessary to call directly.

        :return: train metric result
        """
        network = self.state['network']
        network = network.train() or network
        optimizer = self.state['optimizer']
        loss = self.state['loss']
        loss_metric = AverageValue()
        metric = self.state['metric']
        metric.reset()
        policy = self.state['train_policy']
        n = self.state['train_iter'].epoch_iterations
        m = self.state['train_iter'].num_mini_batches

        printer = make_printer(desc='TRAIN',
                               total=n,
                               disable=self.state['quiet'])
        train_metric = dict()
        self._run_hooks('train_begin')
        for iteration, mini_batch in self.state['train_iter']:
            optimizer.zero_grad()
            for sample in mini_batch:
                self._run_hooks('train_pre_forward', sample=sample)
                output = network.forward(sample)
                l = loss(output)
                train_metric.update(
                    metric(output),
                    # AverageValue tracks the running mean of the loss over the epoch so far
                    loss=loss_metric(l).item())

                if m > 1:
                    l /= m
                self._run_hooks('train_forward',
                                metric=train_metric,
                                loss=l,
                                output=output)
                l.backward()
                self._run_hooks('train_backward',
                                metric=train_metric,
                                loss=l,
                                output=output)

                # save input, output and target images twice per epoch
                if iteration % int(n / 2) == 0:
                    save_images(sample['image'][0], 'input',
                                str(iteration // self.state['epoch']))
                    save_images(output['output'][0], 'output',
                                str(iteration // self.state['epoch']))
                    save_images(
                        output['target_image'][0], 'output',
                        'target_' + str(iteration // self.state['epoch']))

            # the policy expects an iteration-based progress value, which is not
            # directly controllable here, so step it with the fractional epoch
            policy.step(iteration / n, train_metric['loss'])
            optimizer.step()
            printer(**train_metric)

        self.state['train_metric_values'].append(train_metric)
        self._run_hooks('train_end')
        return train_metric
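Every example feeds the loss through AverageValue before logging it, which suggests a running-mean helper: each call adds a value and returns the mean so far, so loss_metric(l).item() yields the average loss over the epoch up to that point. A minimal sketch under that assumption (the real class may differ):

    class AverageValue:
        """Running mean: each call adds a value and returns the mean so far."""

        def __init__(self):
            self.total = 0.0
            self.count = 0

        def __call__(self, value):
            # detach tensors so the running sum does not keep the autograd graph alive
            if hasattr(value, 'detach'):
                value = value.detach()
            self.total += value
            self.count += 1
            return self.total / self.count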
Example 5
    def train_epoch(self):
        """
        Train one epoch. Invoked by the train function; usually not necessary to call directly.

        :return: train metric result
        """
        network = self.state['network']
        network = network.train() or network
        optimizer = self.state['optimizer']
        loss = self.state['loss']
        loss_metric = AverageValue()
        cl_loss_metric = AverageValue()
        ae_loss_metric = AverageValue()
        metric = self.state['metric']
        metric.reset()
        policy = self.state['train_policy']
        n = self.state['train_iter'].epoch_iterations
        m = self.state['train_iter'].num_mini_batches

        printer = make_printer(desc='TRAIN',
                               total=n,
                               disable=self.state['quiet'])
        train_metric = dict()
        self._run_hooks('train_begin')
        for iteration, mini_batch in self.state['train_iter']:
            optimizer.zero_grad()
            for sample in mini_batch:
                self._run_hooks('train_pre_forward', sample=sample)

                output, log_vars = network.forward(sample)
                l, log_vars, both_loss = loss(output, log_vars)

                train_metric.update(
                    metric(output),
                    loss=loss_metric(l.item()),
                    sigma_cl=log_vars[0].item(),
                    sigma_ae=log_vars[1].item(),
                    cl_loss=cl_loss_metric(both_loss[0].item()),
                    ae_loss=ae_loss_metric(both_loss[1].item()))

                if m > 1:
                    l /= m
                self._run_hooks('train_forward',
                                metric=train_metric,
                                loss=l,
                                output=output)
                l.backward()
                self._run_hooks('train_backward',
                                metric=train_metric,
                                loss=l,
                                output=output)

                # save an input image and a grid of output images once per epoch
                if iteration % n == 0:
                    save_image_grid(sample['image'][0], 'input',
                                    self.state['rank'],
                                    str(int(iteration / self.state['epoch'])),
                                    self.state['outdir'])
                    save_image_grid(output['output'][0:9], 'output',
                                    self.state['rank'],
                                    str(int(iteration / self.state['epoch'])),
                                    self.state['outdir'])

            # the policy expects an iteration-based progress value, which is not
            # directly controllable here, so step it with the fractional epoch
            policy.step(iteration / n, train_metric['loss'])
            optimizer.step()
            printer(**train_metric)

        self.state['train_metric_values'].append(train_metric)
        self._run_hooks('train_end')
        return train_metric
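In Example 5 the network also returns learned log-variances, and loss(output, log_vars) reports them as sigma_cl and sigma_ae next to the per-task losses. This is consistent with uncertainty-based multi-task weighting, where each task loss is scaled by exp(-log_var) and the log-variance itself is added as a regulariser. A sketch of such a combination (an assumption about what the loss might do, matching its return signature but not necessarily its implementation):

    import torch

    def combine_losses(cl_loss, ae_loss, log_vars):
        """Weight two task losses by learned log-variances (uncertainty weighting)."""
        weighted_cl = torch.exp(-log_vars[0]) * cl_loss + log_vars[0]
        weighted_ae = torch.exp(-log_vars[1]) * ae_loss + log_vars[1]
        total = weighted_cl + weighted_ae
        return total, log_vars, (cl_loss, ae_loss)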