Example #1
    def optimize_epoch(self, optimizer, loader, epoch, validation=False):
        print(f"Starting epoch {epoch}, validation: {validation} " + "="*30,flush=True)

        loss_value = util.AverageMeter()
        # housekeeping
        self.model.train()
        if self.lr_schedule(epoch + 1) != self.lr_schedule(epoch):
            files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                                      optimizer, self.L, epoch, lowest=False, save_str='pre-lr-drop')
        lr = self.lr_schedule(epoch)
        for pg in optimizer.param_groups:
            pg['lr'] = lr
        XE = torch.nn.CrossEntropyLoss()
        for iter, (data, label, selected) in enumerate(loader):
            now = time.time()
            niter = epoch * len(loader) + iter

            if niter*args.batch_size >= self.optimize_times[-1]:
                ############ optimize labels #########################################
                self.model.headcount = 1
                print('Optimization starting', flush=True)
                with torch.no_grad():
                    _ = self.optimize_times.pop()
                    self.optimize_labels(niter)
            data = data.to(self.dev)
            mass = data.size(0)
            final = self.model(data)
            #################### train CNN ####################################################
            if self.hc == 1:
                loss = XE(final, self.L[0, selected])
            else:
                loss = torch.mean(torch.stack([XE(final[h],
                                                  self.L[h, selected]) for h in range(self.hc)]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value.update(loss.item(), mass)
            data = 0  # drop the reference so the batch tensor can be freed

            # some logging stuff ##############################################################
            if iter % args.log_iter == 0:
                if self.writer:
                    self.writer.add_scalar('lr', self.lr_schedule(epoch), niter)

                    print(niter, " Loss: {0:.3f}".format(loss.item()), flush=True)
                    print(niter, " Freq: {0:.2f}".format(mass/(time.time() - now)), flush=True)
                    self.writer.add_scalar('Loss', loss.item(), niter)
                    if iter > 0:
                        self.writer.add_scalar('Freq(Hz)', mass/(time.time() - now), niter)


        # end of epoch logging ################################################################
        if self.writer and (epoch % args.log_intv == 0):
            util.write_conv(self.writer, self.model, epoch=epoch)

        files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                                  optimizer,  self.L, epoch, lowest=False)

        return {'loss': loss_value.avg}
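
Both this loop and the ones below lean on a util.AverageMeter helper that is never shown. Its call sites (update(value, n) followed by reads of .avg) pin down the contract; a minimal sketch consistent with that usage, assuming nothing beyond those call sites, could be:

class AverageMeter:
    """Running weighted mean: update(val, n) folds in a batch mean with weight n."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is a per-batch average (e.g. loss.item()); weight it by the batch size
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count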
Example #2
    def train_on_epoch(self, optimizer, loader, epoch, validation=False):
        print(f"Starting epoch {epoch}, validation: {validation} " + "=" * 30, flush=True)

        loss_value = util.AverageMeter()
        # housekeeping
        self.model.train()
        if self.lr_schedule(epoch + 1) != self.lr_schedule(epoch):
            files.save_checkpoint_all(
                self.checkpoint_dir, self.model, args.arch,
                optimizer, self.L, epoch, lowest=False, save_str='pre-lr-drop')
        lr = self.lr_schedule(epoch)
        for pg in optimizer.param_groups:
            pg['lr'] = lr
        criterion_fn = torch.nn.CrossEntropyLoss()
        for index, (data, label, selected) in enumerate(loader):
            start_tm = time.time()
            global_step = epoch * len(loader) + index

            if global_step * args.batch_size >= self.optimize_times[-1]:
                # optimize labels #########################################
                self.model.headcount = 1
                print('Optimization starting', flush=True)
                with torch.no_grad():
                    _ = self.optimize_times.pop()
                    self.update_assignment(global_step)
            data = data.to(self.device)
            mass = data.size(0)
            outputs = self.model(data)
            # train CNN ####################################################
            if self.num_heads == 1:
                loss = criterion_fn(outputs, self.L[0, selected])
            else:
                loss = torch.mean(torch.stack([
                    criterion_fn(outputs[head_index], self.L[head_index, selected]) for head_index in
                    range(self.num_heads)]
                ))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_value.update(loss.item(), mass)
            data = 0  # drop the reference so the batch tensor can be freed

            # some logging stuff ##############################################################
            if index % args.log_iter == 0 and self.writer:
                self.writer.add_scalar('lr', self.lr_schedule(epoch), global_step)

                print(global_step, f" Loss: {loss.item():.3f}", flush=True)
                print(global_step, f" Freq: {mass / (time.time() - start_tm):.2f}", flush=True)
                self.writer.add_scalar('Loss', loss.item(), global_step)
                if index > 0:
                    self.writer.add_scalar('Freq(Hz)', mass / (time.time() - start_tm), global_step)

        # end of epoch logging ################################################################
        if self.writer and (epoch % args.log_intv == 0):
            util.write_conv(self.writer, self.model, epoch=epoch)

        files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch, optimizer, self.L, epoch, lowest=False)

        return {'loss': loss_value.avg}
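
The epoch loops also assume a self.lr_schedule callable that maps an epoch index to a learning rate; comparing epoch and epoch + 1 detects an upcoming drop so a 'pre-lr-drop' checkpoint can be written. A hypothetical step schedule matching that usage (base_lr, drop_every and gamma are illustrative values, not from the source):

def lr_schedule(epoch, base_lr=0.05, drop_every=30, gamma=0.1):
    # step schedule: scale the base LR by gamma once every drop_every epochs
    return base_lr * (gamma ** (epoch // drop_every))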
Example #3
    def optimize(self):
        """Perform full optimization."""
        first_epoch = 0
        self.model = self.model.to(self.dev)
        N = len(self.pseudo_loader.dataset)
        # optimization times (spread exponentially), can also just be linear in practice (i.e. every n-th epoch)
        self.optimize_times = [(self.num_epochs+2)*N] + \
                              ((self.num_epochs+1.01)*N*(np.linspace(0, 1, args.nopts)**2)[::-1]).tolist()

        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           self.model.parameters()),
                                    weight_decay=self.weight_decay,
                                    momentum=self.momentum,
                                    lr=self.lr)

        if self.checkpoint_dir is not None and self.resume:
            self.L, first_epoch = files.load_checkpoint_all(
                self.checkpoint_dir, self.model, optimizer)
            print('found first epoch to be', first_epoch, flush=True)
            include = [(qq / N >= first_epoch) for qq in self.optimize_times]
            self.optimize_times = (np.array(
                self.optimize_times)[include]).tolist()
        print('We will optimize L at epochs:',
              [np.round(1.0 * t / N, 2) for t in self.optimize_times],
              flush=True)

        if first_epoch == 0:
            # initialize labels as shuffled.
            self.L = np.zeros((self.hc, N), dtype=np.int32)
            for nh in range(self.hc):
                for _i in range(N):
                    self.L[nh, _i] = _i % self.outs[nh]
                self.L[nh] = np.random.permutation(self.L[nh])
            self.L = torch.LongTensor(self.L).to(self.dev)

        # Perform optimization ###############################################################
        lowest_loss = 1e9
        epoch = first_epoch
        while epoch < (self.num_epochs + 1):
            m = self.optimize_epoch(optimizer,
                                    self.train_loader,
                                    epoch,
                                    validation=False)
            if m['loss'] < lowest_loss:
                lowest_loss = m['loss']
                files.save_checkpoint_all(self.checkpoint_dir,
                                          self.model,
                                          args.arch,
                                          optimizer,
                                          self.L,
                                          epoch,
                                          lowest=True)
            epoch += 1
        print(
            f"Optimization completed. Saving model to {os.path.join(self.checkpoint_dir, 'model_final.pth.tar')}"
        )
        torch.save(self.model,
                   os.path.join(self.checkpoint_dir, 'model_final.pth.tar'))
        return self.model
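
The optimize_times list built in this method holds descending thresholds measured in samples seen; whenever niter * batch_size passes the smallest remaining entry, optimize_epoch pops it and runs a label-optimization round. A tiny standalone illustration with made-up sizes:

import numpy as np

num_epochs, N, nopts = 2, 10, 4  # hypothetical: 2 epochs, 10 samples, 4 rounds
times = [(num_epochs + 2) * N] + \
        ((num_epochs + 1.01) * N * (np.linspace(0, 1, nopts) ** 2)[::-1]).tolist()
print(times)  # [40, 30.1, ~13.38, ~3.34, 0.0] -- quadratically spaced, descending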
Example #4
    def optimize(self, model, train_loader):
        """Perform full optimization."""
        first_epoch = 0
        model = model.to(self.dev)
        self.optimize_times = [0]
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    weight_decay=self.weight_decay,
                                    momentum=self.momentum,
                                    lr=self.lr)
        if self.checkpoint_dir is not None and self.resume:
            self.L, first_epoch = files.load_checkpoint_all(
                self.checkpoint_dir, model=None, opt=None)
            print('loaded from: ', self.checkpoint_dir, flush=True)
            print('first five entries of L: ', self.L[:5], flush=True)
            print('found first epoch to be', first_epoch, flush=True)
            first_epoch = 0  # restart epoch counting; only the loaded labels L are reused
            self.optimize_times = [0]
            self.L = self.L.cuda()
            print("model.headcount ", model.headcount, flush=True)

        #####################################################################################
        # Perform optimization ###############################################################
        lowest_loss = 1e9
        epoch = first_epoch
        while epoch < (self.num_epochs + 1):
            if not args.val_only:
                m = self.optimize_epoch(model,
                                        optimizer,
                                        train_loader,
                                        epoch,
                                        validation=False)
                if m['loss'] < lowest_loss:
                    lowest_loss = m['loss']
                    files.save_checkpoint_all(self.checkpoint_dir,
                                              model,
                                              args.arch,
                                              optimizer,
                                              self.L,
                                              epoch,
                                              lowest=True)
            else:
                print('=' * 30 + ' doing only validation ' + "=" * 30)
                epoch = self.num_epochs
            m = self.optimize_epoch(model,
                                    optimizer,
                                    self.val_loader,
                                    epoch,
                                    validation=True)
            epoch += 1
        print(
            f"Model optimization completed. Saving final model to {os.path.join(self.checkpoint_dir, 'model_final.pth.tar')}"
        )
        torch.save(model,
                   os.path.join(self.checkpoint_dir, 'model_final.pth.tar'))
        return model
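
files.save_checkpoint_all and files.load_checkpoint_all are external helpers not shown here. Judging only from the call sites (model, optimizer, the label tensor L and the epoch go in; L and a resume epoch come out), a hypothetical implementation might look like this; the file names and dict keys below are invented for illustration:

import os
import torch

def save_checkpoint_all(ckpt_dir, model, arch, optimizer, L, epoch,
                        lowest=False, save_str=''):
    # bundle everything needed to resume: weights, optimizer state, labels L
    name = 'lowest' if lowest else (save_str or 'checkpoint')
    torch.save({'arch': arch, 'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(), 'L': L, 'epoch': epoch},
               os.path.join(ckpt_dir, name + '.pth.tar'))

def load_checkpoint_all(ckpt_dir, model=None, opt=None):
    ckpt = torch.load(os.path.join(ckpt_dir, 'checkpoint.pth.tar'))
    if model is not None:
        model.load_state_dict(ckpt['state_dict'])
    if opt is not None:
        opt.load_state_dict(ckpt['optimizer'])
    return ckpt['L'], ckpt['epoch'] + 1  # resume after the saved epoch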
Example #5
            loss = loss_func(pre_label, label)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if iter % 25 == 0:
                print('epoch:{}, loss:{:.4f}'.format(epoch, loss.item()))
                writer.add_scalar('loss', loss.item(),
                                  iter + epoch * len(train_loader))
                writer.add_scalar('lr', lr, iter + epoch * len(train_loader))

        # save checkpoints
        files.save_checkpoint_all(checkpoint_dir,
                                  model,
                                  'alexnet',
                                  optimizer,
                                  pre_label,
                                  epoch,
                                  lowest=False)

        _, predicted = torch.max(pre_label.data, 1)
        total = label.size(0)
        correct = (predicted == label).sum().item()
        print('Epoch %d accuracy: %d%%' % (epoch + 1, (100 * correct // total)))
        writer.add_scalar('correct', (100 * correct // total),
                          (epoch + 1) * len(train_loader))

    writer.close()
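
The writer used throughout behaves like torch.utils.tensorboard.SummaryWriter, i.e. add_scalar(tag, value, step) plus a final close(); minimal setup, with an illustrative log directory:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/demo')  # hypothetical log directory
writer.add_scalar('loss', 0.5, 0)            # tag, scalar value, global step
writer.close()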
Example #6
    def optimize_epoch(self, model, optimizer, loader, epoch, validation=False):
        print(f"Starting epoch {epoch}, validation: {validation} " + "=" * 30)
        loss_value = AverageMeter()
        rotacc_value = AverageMeter()

        # housekeeping
        if not validation:
            model.train()
            lr = self.lr_schedule(epoch)
            for pg in optimizer.param_groups:
                pg['lr'] = lr
        else:
            model.eval()

        XE = torch.nn.CrossEntropyLoss().to(self.dev)
        l_dl = len(loader)
        now = time.time()
        batch_time = MovingAverage(intertia=0.9)  # moving average of per-batch wall time
        for iter, (data, label, selected) in enumerate(loader):
            now = time.time()

            if not validation:
                niter = epoch * len(loader.dataset) + iter * args.batch_size
            data = data.to(self.dev)
            mass = data.size(0)
            where = np.arange(mass, dtype=int) * 4
            data = data.view(mass * 4, 3, data.size(3), data.size(4))
            rotlabel = torch.tensor(range(4)).view(-1, 1).repeat(mass, 1).view(-1).to(self.dev)
            #################### train CNN ###########################################
            if not validation:
                final = model(data)
                if args.onlyrot:
                    loss = torch.Tensor([0]).to(self.dev)
                else:
                    if args.hc == 1:
                        loss = XE(final[0][where], self.L[selected])
                    else:
                        loss = torch.mean(
                            torch.stack([XE(final[k][where], self.L[k, selected]) for k in range(args.hc)]))
                rotloss = XE(final[-1], rotlabel)
                pred = torch.argmax(final[-1], 1)

                total_loss = loss + rotloss
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()
                correct = (pred == rotlabel).to(torch.float)
                rotacc = correct.sum() / float(mass)
            else:
                final = model(data)
                pred = torch.argmax(final[-1], 1)
                correct = (pred == rotlabel.cuda()).to(torch.float)
                rotacc = correct.sum() / float(mass)
                total_loss = torch.Tensor([0])
                loss = torch.Tensor([0])
                rotloss = torch.Tensor([0])
            rotacc_value.update(rotacc.item(), mass)
            loss_value.update(total_loss.item(), mass)

            batch_time.update(time.time() - now)
            now = time.time()
            print(
                f"Loss: {loss_value.avg:03.3f}, RotAcc: {rotacc_value.avg:03.3f} | {epoch: 3}/{iter:05}/{l_dl:05} Freq: {mass / batch_time.avg:04.1f}Hz:",
                end='\r', flush=True)

            # every few iter logging
            if iter % args.logiter == 0:
                if not validation:
                    print(niter, f" Loss: {loss.item():.3f}", flush=True)
                    with torch.no_grad():
                        if not args.onlyrot:
                            pred = torch.argmax(final[0][where], dim=1)
                            pseudoloss = XE(final[0][where], pred)
                    if not args.onlyrot:
                        self.writer.add_scalar('Pseudoloss', pseudoloss.item(), niter)
                    self.writer.add_scalar('lr', self.lr_schedule(epoch), niter)
                    self.writer.add_scalar('Loss', loss.item(), niter)
                    self.writer.add_scalar('RotLoss', rotloss.item(), niter)
                    self.writer.add_scalar('RotAcc', rotacc.item(), niter)

                    if iter > 0:
                        # now was reset after batch_time.update above, so use the
                        # tracked average batch time rather than time.time() - now
                        self.writer.add_scalar('Freq(Hz)', mass / batch_time.avg, niter)

        # end of epoch logging
        if self.writer and (epoch % self.log_interval == 0):
            write_conv(self.writer, model, epoch)
            if validation:
                print('val Rot-Acc: ', rotacc_value.avg)
                self.writer.add_scalar('val Rot-Acc', rotacc_value.avg, epoch)

        files.save_checkpoint_all(self.checkpoint_dir, model, args.arch,
                                  optimizer, self.L, epoch, lowest=False)
        return {'loss': loss_value.avg}
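
Example #6 pushes four rotated copies of every image through the network in a single batch: rotlabel assigns each copy its rotation class and where indexes the unrotated rows for the pseudo-label loss. A standalone check of those two constructions for a batch of mass = 2:

import numpy as np
import torch

mass = 2
where = np.arange(mass, dtype=int) * 4  # rows holding the 0-degree view -> [0, 4]
rotlabel = torch.tensor(range(4)).view(-1, 1).repeat(mass, 1).view(-1)
print(rotlabel)  # tensor([0, 1, 2, 3, 0, 1, 2, 3])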