Example #1
    def optimize_epoch(self, optimizer, loader, epoch, validation=False):
        print(f"Starting epoch {epoch}, validation: {validation} " + "="*30,flush=True)

        loss_value = util.AverageMeter()
        # housekeeping
        self.model.train()
        if self.lr_schedule(epoch + 1) != self.lr_schedule(epoch):
            files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                                      optimizer, self.L, epoch, lowest=False, save_str='pre-lr-drop')
        lr = self.lr_schedule(epoch)
        for pg in optimizer.param_groups:
            pg['lr'] = lr
        XE = torch.nn.CrossEntropyLoss()
        for iter, (data, label, selected) in enumerate(loader):
            now = time.time()
            niter = epoch * len(loader) + iter

            if niter*args.batch_size >= self.optimize_times[-1]:
                ############ optimize labels #########################################
                self.model.headcount = 1
                print('Optimization starting', flush=True)
                with torch.no_grad():
                    _ = self.optimize_times.pop()
                    self.optimize_labels(niter)
            data = data.to(self.dev)
            mass = data.size(0)
            final = self.model(data)
            #################### train CNN ####################################################
            if self.hc == 1:
                loss = XE(final, self.L[0, selected])
            else:
                loss = torch.mean(torch.stack([XE(final[h],
                                                  self.L[h, selected]) for h in range(self.hc)]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value.update(loss.item(), mass)
            data = 0  # drop the reference to the batch tensor

            # some logging stuff ##############################################################
            if iter % args.log_iter == 0:
                if self.writer:
                    self.writer.add_scalar('lr', self.lr_schedule(epoch), niter)

                    print(niter, " Loss: {0:.3f}".format(loss.item()), flush=True)
                    print(niter, " Freq: {0:.2f}".format(mass/(time.time() - now)), flush=True)
                    self.writer.add_scalar('Loss', loss.item(), niter)
                    if iter > 0:
                        self.writer.add_scalar('Freq(Hz)', mass/(time.time() - now), niter)


        # end of epoch logging ################################################################
        if self.writer and (epoch % args.log_intv == 0):
            util.write_conv(self.writer, self.model, epoch=epoch)

        files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                                  optimizer, self.L, epoch, lowest=False)

        return {'loss': loss_value.avg}
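
The method above leans on helpers the snippet does not define (util.AverageMeter, self.lr_schedule, files.save_checkpoint_all, the global args). A minimal sketch of the two that drive the control flow, assuming a batch-size-weighted running mean and a step learning-rate schedule; the names and defaults are illustrative, not the repo's actual code:

    class AverageMeter:
        # Running mean weighted by sample count, matching loss_value.update(loss, mass)
        def __init__(self):
            self.sum, self.count, self.avg = 0.0, 0, 0.0

        def update(self, val, n=1):
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count


    def make_step_lr_schedule(base_lr=0.05, drop_every=100, gamma=0.1):
        # epoch -> lr; constant within a step and dropped by gamma every
        # drop_every epochs, so lr_schedule(epoch + 1) != lr_schedule(epoch)
        # holds exactly once per drop, which is what triggers the
        # 'pre-lr-drop' checkpoint above.
        def lr_schedule(epoch):
            return base_lr * (gamma ** (epoch // drop_every))
        return lr_schedule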
Example #2
    def train_on_epoch(self, optimizer, loader, epoch, validation=False):
        print(f"Starting epoch {epoch}, validation: {validation} " + "=" * 30, flush=True)

        loss_value = util.AverageMeter()
        # housekeeping
        self.model.train()
        if self.lr_schedule(epoch + 1) != self.lr_schedule(epoch):
            files.save_checkpoint_all(
                self.checkpoint_dir, self.model, args.arch,
                optimizer, self.L, epoch, lowest=False, save_str='pre-lr-drop')
        lr = self.lr_schedule(epoch)
        for pg in optimizer.param_groups:
            pg['lr'] = lr
        criterion_fn = torch.nn.CrossEntropyLoss()
        for index, (data, label, selected) in enumerate(loader):
            start_tm = time.time()
            global_step = epoch * len(loader) + index

            if global_step * args.batch_size >= self.optimize_times[-1]:
                # optimize labels #########################################
                self.model.headcount = 1
                print('Optimization starting', flush=True)
                with torch.no_grad():
                    _ = self.optimize_times.pop()
                    self.update_assignment(global_step)
            data = data.to(self.device)
            mass = data.size(0)
            outputs = self.model(data)
            # train CNN ####################################################
            if self.num_heads == 1:
                loss = criterion_fn(outputs, self.L[0, selected])
            else:
                loss = torch.mean(torch.stack([
                    criterion_fn(outputs[head_index], self.L[head_index, selected]) for head_index in
                    range(self.num_heads)]
                ))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value.update(loss.item(), mass)
            data = 0  # drop the reference to the batch tensor

            # some logging stuff ##############################################################
            if index % args.log_iter == 0 and self.writer:
                self.writer.add_scalar('lr', self.lr_schedule(epoch), global_step)

                print(global_step, f" Loss: {loss.item():.3f}", flush=True)
                print(global_step, f" Freq: {mass / (time.time() - start_tm):.2f}", flush=True)
                self.writer.add_scalar('Loss', loss.item(), global_step)
                if index > 0:
                    self.writer.add_scalar('Freq(Hz)', mass / (time.time() - start_tm), global_step)

        # end of epoch logging ################################################################
        if self.writer and (epoch % args.log_intv == 0):
            util.write_conv(self.writer, self.model, epoch=epoch)

        files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch, optimizer, self.L, epoch, lowest=False)

        return {'loss': loss_value.avg}
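
Both examples reduce the multi-head case the same way: one cross-entropy per head against that head's row of the pseudo-label table self.L, stacked and averaged into one scalar. A self-contained sketch of just that reduction; the head count, class count, and the shape of L are made up for illustration:

    import torch

    torch.manual_seed(0)
    num_heads, batch, classes, dataset_size = 3, 8, 10, 100
    # One logit tensor per head, as a multi-head model would return
    outputs = [torch.randn(batch, classes) for _ in range(num_heads)]
    # Pseudo-label table indexed as L[head, sample]; 'selected' holds the
    # dataset indices of the current batch, mirroring self.L[h, selected]
    L = torch.randint(0, classes, (num_heads, dataset_size))
    selected = torch.arange(batch)

    criterion = torch.nn.CrossEntropyLoss()
    loss = torch.mean(torch.stack(
        [criterion(outputs[h], L[h, selected]) for h in range(num_heads)]))
    print(loss.item())  # one scalar: the mean of the per-head losses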
Example #3
    def optimize_epoch(self, model, optimizer, loader, epoch, validation=False):
        print(f"Starting epoch {epoch}, validation: {validation} " + "=" * 30)
        loss_value = AverageMeter()
        rotacc_value = AverageMeter()

        # housekeeping
        if not validation:
            model.train()
            lr = self.lr_schedule(epoch)
            for pg in optimizer.param_groups:
                pg['lr'] = lr
        else:
            model.eval()

        XE = torch.nn.CrossEntropyLoss().to(self.dev)
        l_dl = len(loader)
        now = time.time()
        batch_time = MovingAverage(intertia=0.9)
        for iter, (data, label, selected) in enumerate(loader):
            now = time.time()

            if not validation:
                niter = epoch * len(loader.dataset) + iter * args.batch_size
            data = data.to(self.dev)
            mass = data.size(0)
            where = np.arange(mass, dtype=int) * 4
            data = data.view(mass * 4, 3, data.size(3), data.size(4))
            rotlabel = torch.tensor(range(4)).view(-1, 1).repeat(mass, 1).view(-1).to(self.dev)
            #################### train CNN ###########################################
            if not validation:
                final = model(data)
                if args.onlyrot:
                    loss = torch.Tensor([0]).to(self.dev)
                else:
                    if args.hc == 1:
                        loss = XE(final[0][where], self.L[selected])
                    else:
                        loss = torch.mean(
                            torch.stack([XE(final[k][where], self.L[k, selected]) for k in range(args.hc)]))
                rotloss = XE(final[-1], rotlabel)
                pred = torch.argmax(final[-1], 1)

                total_loss = loss + rotloss
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                correct = (pred == rotlabel).to(torch.float)
                rotacc = correct.sum() / float(mass)
            else:
                final = model(data)
                pred = torch.argmax(final[-1], 1)
                correct = (pred == rotlabel).to(torch.float)
                rotacc = correct.sum() / float(mass)
                total_loss = torch.Tensor([0])
                loss = torch.Tensor([0])
                rotloss = torch.Tensor([0])
            rotacc_value.update(rotacc.item(), mass)
            loss_value.update(total_loss.item(), mass)

            batch_time.update(time.time() - now)
            now = time.time()
            print(
                f"Loss: {loss_value.avg:03.3f}, RotAcc: {rotacc_value.avg:03.3f} | {epoch: 3}/{iter:05}/{l_dl:05} Freq: {mass / batch_time.avg:04.1f}Hz:",
                end='\r', flush=True)

            # every few iter logging
            if iter % args.logiter == 0:
                if not validation:
                    print(niter, f" Loss: {loss.item():.3f}", flush=True)
                    with torch.no_grad():
                        if not args.onlyrot:
                            pred = torch.argmax(final[0][where], dim=1)
                            pseudoloss = XE(final[0][where], pred)
                    if not args.onlyrot:
                        self.writer.add_scalar('Pseudoloss', pseudoloss.item(), niter)
                    self.writer.add_scalar('lr', self.lr_schedule(epoch), niter)
                    self.writer.add_scalar('Loss', loss.item(), niter)
                    self.writer.add_scalar('RotLoss', rotloss.item(), niter)
                    self.writer.add_scalar('RotAcc', rotacc.item(), niter)

                    if iter > 0:
                        self.writer.add_scalar('Freq(Hz)', mass / batch_time.avg, niter)

        # end of epoch logging
        if self.writer and (epoch % self.log_interval == 0):
            write_conv(self.writer, model, epoch)
            if validation:
                print('val Rot-Acc: ', rotacc_value.avg)
                self.writer.add_scalar('val Rot-Acc', rotacc_value.avg, epoch)

        files.save_checkpoint_all(self.checkpoint_dir, model, args.arch,
                                  optimizer, self.L, epoch, lowest=False)
        return {'loss': loss_value.avg}
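
Example #3 assumes the loader packs four rotated copies of each image into one sample, so data arrives as (mass, 4, 3, H, W). A standalone sketch of the reshaping, the rotation targets, and the 'where' index that picks the unrotated copies for the clustering loss; the tensor sizes are illustrative:

    import numpy as np
    import torch

    mass, H, W = 2, 32, 32
    data = torch.randn(mass, 4, 3, H, W)  # 4 rotated copies per image

    # Flatten rotations into the batch dimension: (mass * 4, 3, H, W)
    flat = data.view(mass * 4, 3, data.size(3), data.size(4))

    # Rotation targets repeated per image: [0, 1, 2, 3, 0, 1, 2, 3, ...]
    rotlabel = torch.tensor(range(4)).view(-1, 1).repeat(mass, 1).view(-1)

    # Every 4th row is the 0-degree copy; final[0][where] above trains the
    # clustering head on those rows only
    where = np.arange(mass, dtype=int) * 4
    print(flat.shape, rotlabel.tolist(), where)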