示例#1
0
    def test(self, epoch):
        """Evaluate the model on ``self.test_loader`` and return top-1 accuracy.

        Runs the whole test loader under ``torch.no_grad()`` with the model in
        eval mode, tracking per-batch time, loss, and top-1 / top-5 accuracy in
        running meters. On the process whose ``args.local_rank`` is 0, progress
        is printed every ``args.print_freq`` batches, and the final averages
        plus the graph summaries are written to ``self.writer``.

        Args:
            epoch: Index of the epoch being evaluated; ``epoch + 1`` is used as
                the step value for everything written to ``self.writer``.

        Returns:
            The running average held by the top-1 accuracy meter.
        """
        time_meter = AverageMeter('Time', ':6.3f')
        loss_meter = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        meters = [time_meter, loss_meter, top1, top5]
        progress = ProgressMeter(len(self.test_loader), meters, prefix='Test: ')

        # Put the model into evaluation mode before scoring.
        self.model.eval()

        with torch.no_grad():
            tick = time.time()
            for batch_idx, (images, target) in enumerate(self.test_loader):
                images, target = images.cuda(), target.cuda()
                batch_size = images.size(0)

                # The model returns a pair; only the first element is scored.
                output, _ = self.model(images)
                loss = self.criterion(output, target)

                # Fold this batch into the running loss / accuracy meters.
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                loss_meter.update(loss.item(), batch_size)
                top1.update(acc1[0], batch_size)
                top5.update(acc5[0], batch_size)

                # Wall-clock time spent on this batch.
                time_meter.update(time.time() - tick)
                tick = time.time()

                on_print_step = batch_idx % self.args.print_freq == 0
                if on_print_step and self.args.local_rank == 0:
                    progress.display(batch_idx)

            # Only the rank-0 process prints and writes summaries.
            if self.args.local_rank == 0:
                step = epoch + 1
                summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
                    top1=top1, top5=top5)
                print(summary)
                scalars = (
                    ('Test/Avg_Loss', loss_meter),
                    ('Test/Avg_Top1', top1),
                    ('Test/Avg_Top5', top5),
                )
                for tag, meter in scalars:
                    self.writer.add_scalar(tag, meter.avg, step)
                self.summary_graph_adj(self.writer, step)
                self.summary_graph_histogram(self.writer, step)

        return top1.avg
示例#2
0
    def train_epoch(self, epoch):
        """Train the model for one pass over ``self.train_loader``.

        For every batch this moves the data to the GPU, computes the criterion
        loss on the main logits, optionally adds an L2 penalty on the activated
        graph weights (when ``args.graph_wd > 0``) and a weighted auxiliary
        loss (when ``args.auxiliary`` is set), back-propagates, optionally
        clips the gradient norm (when ``args.grad_clip > 0``) and steps the
        optimizer. Running meters and ``self.moving_loss`` are updated on every
        batch. On the process whose ``args.local_rank`` is 0, they are printed
        and written to ``self.writer`` every ``args.print_freq`` batches.

        Args:
            epoch: Index of the current epoch. Used in the progress prefix, to
                derive the global iteration number for logging, and (compared
                with ``args.start_epoch``) to detect the very first training
                step, where ``self.moving_loss`` is seeded instead of blended.
        """
        batch_time = AverageMeter('Time', ':6.3f')
        data_time = AverageMeter('Data', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        top5 = AverageMeter('Acc@5', ':6.2f')
        progress = ProgressMeter(len(self.train_loader),
                                 [batch_time, data_time, losses, top1, top5],
                                 prefix="Epoch: [{}]".format(epoch))

        # switch to train mode
        self.model.train()
        end = time.time()

        for i, (images, target) in enumerate(self.train_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            images = images.cuda()
            target = target.cuda()
            batch_size = images.size(0)

            # compute output
            self.optimizer.zero_grad()
            logits, logits_aux = self.model(images)
            loss = self.criterion(logits, target)
            if self.args.graph_wd > 0:
                # L2 penalty over the activated values of every trainable
                # parameter whose name contains 'graph_weights'.
                graph_l2 = sum(
                    (self.model.edge_act(v)**2).sum()
                    for k, v in self.model.named_parameters()
                    if 'graph_weights' in k and v.requires_grad)
                loss += 0.5 * graph_l2 * self.args.graph_wd
            if self.args.auxiliary:
                loss_aux = self.criterion(logits_aux, target)
                loss += self.args.auxiliary_weight * loss_aux
            loss.backward()
            if self.args.grad_clip > 0:
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.args.grad_clip)
            self.optimizer.step()

            # measure accuracy and record loss
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            # Read the scalar loss once: each .item() call is a separate
            # device-to-host transfer, and the value is needed several times.
            loss_value = loss.item()
            losses.update(loss_value, batch_size)
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # Exponential moving average of the loss, seeded with the raw
            # loss on the first step of the run.
            if epoch == self.args.start_epoch and i == 0:
                self.moving_loss = loss_value
            else:
                self.moving_loss = ((1. - self.mu) * self.moving_loss +
                                    self.mu * loss_value)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % self.args.print_freq == 0 and self.args.local_rank == 0:
                progress.display(i)
                # Global iteration count across epochs, used as the log step.
                niter = epoch * len(self.train_loader) + i
                self.writer.add_scalar('Train/Sec_per_batch', batch_time.avg,
                                       niter)
                self.writer.add_scalar('Train/Avg_Loss', losses.avg, niter)
                self.writer.add_scalar('Train/Avg_Top1', top1.avg, niter)
                self.writer.add_scalar('Train/Avg_Top5', top5.avg, niter)
                self.writer.add_scalar('Train/Moving_Loss', self.moving_loss,
                                       niter)