def optimize_epoch(self, optimizer, loader, epoch, validation=False):
    """Train the CNN for one epoch against the current pseudo-labels self.L.

    Sets the learning rate from self.lr_schedule, checkpoints just before a
    scheduled lr drop, and re-optimizes the label assignment whenever the
    number of samples seen crosses the next entry of self.optimize_times.

    Args:
        optimizer: torch optimizer; every param group's lr is overwritten.
        loader: yields (data, label, selected); `selected` indexes into self.L.
        epoch: current epoch number (drives the lr schedule and logging).
        validation: only echoed in the banner; training always runs here.

    Returns:
        dict with the epoch-average loss under key 'loss'.
    """
    print(f"Starting epoch {epoch}, validation: {validation} " + "=" * 30, flush=True)
    loss_value = util.AverageMeter()
    # house keeping
    self.model.train()
    if self.lr_schedule(epoch + 1) != self.lr_schedule(epoch):
        # checkpoint right before the lr drop takes effect
        files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                                  optimizer, self.L, epoch, lowest=False,
                                  save_str='pre-lr-drop')
    lr = self.lr_schedule(epoch)
    for pg in optimizer.param_groups:
        pg['lr'] = lr
    XE = torch.nn.CrossEntropyLoss()
    # renamed loop variable (was `iter`, shadowing the builtin)
    for batch_idx, (data, label, selected) in enumerate(loader):
        now = time.time()
        niter = epoch * len(loader) + batch_idx
        if niter * args.batch_size >= self.optimize_times[-1]:
            ############ optimize labels #########################################
            self.model.headcount = 1
            print('Optimizaton starting', flush=True)
            with torch.no_grad():
                _ = self.optimize_times.pop()
                self.optimize_labels(niter)
        data = data.to(self.dev)
        mass = data.size(0)
        final = self.model(data)
        #################### train CNN ####################################################
        if self.hc == 1:
            loss = XE(final, self.L[0, selected])
        else:
            # one cross-entropy term per head, averaged
            loss = torch.mean(torch.stack(
                [XE(final[h], self.L[h, selected]) for h in range(self.hc)]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value.update(loss.item(), mass)
        data = 0  # drop the reference so the batch tensor can be freed early
        # some logging stuff ##############################################################
        if batch_idx % args.log_iter == 0:
            if self.writer:
                self.writer.add_scalar('lr', self.lr_schedule(epoch), niter)
            print(niter, " Loss: {0:.3f}".format(loss.item()), flush=True)
            print(niter, " Freq: {0:.2f}".format(mass / (time.time() - now)), flush=True)
            # BUGFIX: was `if writer:` — a bare, undefined name (NameError);
            # the guard is on the instance's writer, as everywhere else.
            if self.writer:
                self.writer.add_scalar('Loss', loss.item(), niter)
                if batch_idx > 0:
                    self.writer.add_scalar('Freq(Hz)', mass / (time.time() - now), niter)
    # end of epoch logging ################################################################
    if self.writer and (epoch % args.log_intv == 0):
        util.write_conv(self.writer, self.model, epoch=epoch)
    files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                              optimizer, self.L, epoch, lowest=False)
    return {'loss': loss_value.avg}
def train_on_epoch(self, optimizer, loader, epoch, validation=False):
    """Train the CNN for one epoch against the current label matrix self.L.

    Mirrors optimize_epoch: applies the lr schedule, checkpoints just before
    a scheduled lr drop, and refreshes the pseudo-label assignment when the
    processed-sample count crosses the next entry of self.optimize_times.

    Args:
        optimizer: torch optimizer; every param group's lr is overwritten.
        loader: yields (data, label, selected); `selected` indexes into self.L.
        epoch: current epoch number.
        validation: only echoed in the banner; training always runs here.

    Returns:
        dict with the epoch-average loss under key 'loss'.
    """
    print(f"Starting epoch {epoch}, validation: {validation} " + "=" * 30, flush=True)
    loss_value = util.AverageMeter()
    # house keeping.
    # BUGFIX: was `self.model.run()` — nn.Module defines no such method; the
    # sibling optimize_epoch correctly enables training mode with .train().
    self.model.train()
    if self.lr_schedule(epoch + 1) != self.lr_schedule(epoch):
        # checkpoint right before the lr drop takes effect
        files.save_checkpoint_all(
            self.checkpoint_dir, self.model, args.arch, optimizer, self.L,
            epoch, lowest=False, save_str='pre-lr-drop')
    lr = self.lr_schedule(epoch)
    for pg in optimizer.param_groups:
        pg['lr'] = lr
    criterion_fn = torch.nn.CrossEntropyLoss()
    for index, (data, label, selected) in enumerate(loader):
        start_tm = time.time()
        global_step = epoch * len(loader) + index
        if global_step * args.batch_size >= self.optimize_times[-1]:
            # optimize labels #########################################
            self.model.headcount = 1
            print('Optimizaton starting', flush=True)
            with torch.no_grad():
                _ = self.optimize_times.pop()
                self.update_assignment(global_step)
        data = data.to(self.device)
        mass = data.size(0)
        outputs = self.model(data)
        # train CNN ####################################################
        if self.num_heads == 1:
            loss = criterion_fn(outputs, self.L[0, selected])
        else:
            # one cross-entropy term per head, averaged
            loss = torch.mean(torch.stack([
                criterion_fn(outputs[head_index], self.L[head_index, selected])
                for head_index in range(self.num_heads)]))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value.update(loss.item(), mass)
        data = 0  # drop the reference so the batch tensor can be freed early
        # some logging stuff ##############################################################
        if index % args.log_iter == 0 and self.writer:
            self.writer.add_scalar('lr', self.lr_schedule(epoch), global_step)
            print(global_step, f" Loss: {loss.item():.3f}", flush=True)
            print(global_step, f" Freq: {mass / (time.time() - start_tm):.2f}", flush=True)
            # BUGFIX: was `if writer:` — a bare, undefined name (NameError).
            # self.writer is already known truthy in this branch, so the
            # redundant inner guard is dropped.
            self.writer.add_scalar('Loss', loss.item(), global_step)
            if index > 0:
                self.writer.add_scalar('Freq(Hz)', mass / (time.time() - start_tm), global_step)
    # end of epoch logging ################################################################
    if self.writer and (epoch % args.log_intv == 0):
        util.write_conv(self.writer, self.model, epoch=epoch)
    files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                              optimizer, self.L, epoch, lowest=False)
    return {'loss': loss_value.avg}
def optimize(self):
    """Perform full optimization.

    Builds the label-update schedule, (optionally) resumes from a checkpoint,
    initializes the pseudo-labels on a fresh run, trains for all epochs, and
    saves the final model. Returns the trained model.
    """
    start_epoch = 0
    self.model = self.model.to(self.dev)
    n_samples = len(self.pseudo_loader.dataset)
    # Label-optimization trigger points, expressed in "samples seen": a
    # sentinel past the final epoch, then quadratically spread triggers
    # (could also just be linear in practice, i.e. every n-th epoch).
    spread = (np.linspace(0, 1, args.nopts) ** 2)[::-1]
    self.optimize_times = [(self.num_epochs + 2) * n_samples] + \
        ((self.num_epochs + 1.01) * n_samples * spread).tolist()
    trainable = (p for p in self.model.parameters() if p.requires_grad)
    optimizer = torch.optim.SGD(trainable,
                                weight_decay=self.weight_decay,
                                momentum=self.momentum,
                                lr=self.lr)
    if self.checkpoint_dir is not None and self.resume:
        self.L, start_epoch = files.load_checkpoint_all(
            self.checkpoint_dir, self.model, optimizer)
        print('found first epoch to be', start_epoch, flush=True)
        # keep only the label-update triggers that are still ahead of us
        keep = np.array(self.optimize_times) / n_samples >= start_epoch
        self.optimize_times = np.array(self.optimize_times)[keep].tolist()
    print('We will optimize L at epochs:',
          [np.round(1.0 * t / n_samples, 2) for t in self.optimize_times],
          flush=True)
    if start_epoch == 0:
        # fresh run: initiate labels as a shuffled, balanced assignment
        self.L = np.zeros((self.hc, n_samples), dtype=np.int32)
        for head in range(self.hc):
            self.L[head] = np.random.permutation(
                np.arange(n_samples, dtype=np.int32) % self.outs[head])
        self.L = torch.LongTensor(self.L).to(self.dev)
    # Perform optimization ##############################################################
    best_loss = 1e9
    for epoch in range(start_epoch, self.num_epochs + 1):
        metrics = self.optimize_epoch(optimizer, self.train_loader, epoch,
                                      validation=False)
        if metrics['loss'] < best_loss:
            best_loss = metrics['loss']
            files.save_checkpoint_all(self.checkpoint_dir, self.model, args.arch,
                                      optimizer, self.L, epoch, lowest=True)
    final_path = os.path.join(self.checkpoint_dir, 'model_final.pth.tar')
    print(
        f"optimization completed. Saving model to {final_path}"
    )
    torch.save(self.model, final_path)
    return self.model
def optimize(self, model, train_loader):
    """Perform full optimization.

    Optionally restores the label matrix from a checkpoint (model/optimizer
    state are deliberately NOT reloaded, and the epoch counter is reset),
    then trains for all epochs — or runs a single validation pass when
    args.val_only is set — and saves the final model.
    """
    start_epoch = 0
    model = model.to(self.dev)
    self.optimize_times = [0]
    trainable_params = (p for p in model.parameters() if p.requires_grad)
    optimizer = torch.optim.SGD(trainable_params,
                                weight_decay=self.weight_decay,
                                momentum=self.momentum,
                                lr=self.lr)
    if self.checkpoint_dir is not None and self.resume:
        # restore only the label matrix; model=None/opt=None skips their state
        self.L, start_epoch = files.load_checkpoint_all(
            self.checkpoint_dir, model=None, opt=None)
        print('loaded from: ', self.checkpoint_dir, flush=True)
        print('first five entries of L: ', self.L[:5], flush=True)
        print('found first epoch to be', start_epoch, flush=True)
        # deliberately restart the epoch counter and label-update schedule
        start_epoch = 0
        self.optimize_times = [0]
        self.L = self.L.cuda()
        print("model.headcount ", model.headcount, flush=True)
    #####################################################################################
    # Perform optimization ##############################################################
    best_loss = 1e9
    epoch = start_epoch
    while epoch < (self.num_epochs + 1):
        if args.val_only:
            # validation-only mode: jump to the last epoch and validate once
            print('=' * 30 + ' doing only validation ' + "=" * 30)
            epoch = self.num_epochs
            stats = self.optimize_epoch(model, optimizer, self.val_loader,
                                        epoch, validation=True)
        else:
            stats = self.optimize_epoch(model, optimizer, train_loader,
                                        epoch, validation=False)
            if stats['loss'] < best_loss:
                best_loss = stats['loss']
                files.save_checkpoint_all(self.checkpoint_dir, model, args.arch,
                                          optimizer, self.L, epoch, lowest=True)
        epoch += 1
    final_path = os.path.join(self.checkpoint_dir, 'model_final.pth.tar')
    print(
        f"Model optimization completed. Saving final model to {final_path}"
    )
    torch.save(model, final_path)
    return model
loss = loss_func(pre_label, label) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() if iter % 25 == 0: print('epoch:{}, loss:{:.4f}'.format(epoch, loss.item())) writer.add_scalar('loss', loss.item(), iter + epoch * len(train_loader)) writer.add_scalar('lr', lr, iter + epoch * len(train_loader)) # 保存checkpoints files.save_checkpoint_all(checkpoint_dir, model, 'alexnet', optimizer, pre_label, epoch, lowest=False) _, predicted = torch.max(pre_label.data, 1) total = label.size(0) correct = (predicted == label).sum() print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct // total))) writer.add_scalar('correct', (100 * correct // total), (epoch + 1) * len(train_loader)) writer.close()
def optimize_epoch(self, model, optimizer, loader, epoch, validation=False):
    """Run one epoch of joint pseudo-label + rotation-prediction training.

    Each batch is assumed to carry 4 rotated copies of every image
    (data viewed as (mass*4, 3, H, W) — TODO confirm input is 5-D
    (mass, 4, 3, H, W)); head `final[-1]` predicts the rotation, the other
    head(s) are trained against the pseudo-labels self.L unless
    args.onlyrot is set.

    Args:
        model: the network; final[-1] is the rotation head.
        optimizer: torch optimizer; lr is set from self.lr_schedule when training.
        loader: yields (data, label, selected); `selected` indexes into self.L.
        epoch: current epoch number.
        validation: if True, only rotation accuracy is evaluated (no updates).

    Returns:
        dict with the epoch-average total loss under key 'loss'.
    """
    print(f"Starting epoch {epoch}, validation: {validation} " + "=" * 30)
    loss_value = AverageMeter()
    rotacc_value = AverageMeter()
    # house keeping
    if not validation:
        # BUGFIX: was `model.run()` — nn.Module defines no such method;
        # training mode is enabled with .train().
        model.train()
        lr = self.lr_schedule(epoch)
        for pg in optimizer.param_groups:
            pg['lr'] = lr
    else:
        model.eval()
    XE = torch.nn.CrossEntropyLoss().to(self.dev)
    l_dl = 0  # len(loader)
    now = time.time()
    batch_time = MovingAverage(intertia=0.9)  # keyword spelling matches MovingAverage's API
    for batch_idx, (data, label, selected) in enumerate(loader):
        now = time.time()
        if not validation:
            niter = epoch * len(loader.dataset) + batch_idx * args.batch_size
        data = data.to(self.dev)
        mass = data.size(0)
        # indices of the first (unrotated) copy within each group of 4
        where = np.arange(mass, dtype=int) * 4
        data = data.view(mass * 4, 3, data.size(3), data.size(4))
        # rotation targets: 0,1,2,3 repeated for every image in the batch
        rotlabel = torch.tensor(range(4)).view(-1, 1).repeat(mass, 1).view(-1).to(self.dev)
        #################### train CNN ###########################################
        if not validation:
            final = model(data)
            if args.onlyrot:
                loss = torch.Tensor([0]).to(self.dev)
            else:
                if args.hc == 1:
                    loss = XE(final[0][where], self.L[selected])
                else:
                    # one cross-entropy term per head, averaged
                    loss = torch.mean(torch.stack(
                        [XE(final[k][where], self.L[k, selected]) for k in range(args.hc)]))
            rotloss = XE(final[-1], rotlabel)
            pred = torch.argmax(final[-1], 1)
            total_loss = loss + rotloss
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            correct = (pred == rotlabel).to(torch.float)
            # NOTE(review): `correct` has mass*4 entries but is divided by
            # mass, so this "accuracy" ranges over [0, 4] — confirm intended.
            rotacc = correct.sum() / float(mass)
        else:
            final = model(data)
            pred = torch.argmax(final[-1], 1)
            correct = (pred == rotlabel.cuda()).to(torch.float)
            rotacc = correct.sum() / float(mass)
            # placeholders so the meters below have something to record
            total_loss = torch.Tensor([0])
            loss = torch.Tensor([0])
            rotloss = torch.Tensor([0])
        rotacc_value.update(rotacc.item(), mass)
        loss_value.update(total_loss.item(), mass)
        batch_time.update(time.time() - now)
        now = time.time()
        print(
            f"Loss: {loss_value.avg:03.3f}, RotAcc: {rotacc_value.avg:03.3f} | {epoch: 3}/{batch_idx:05}/{l_dl:05} Freq: {mass / batch_time.avg:04.1f}Hz:",
            end='\r', flush=True)
        # every few iter logging — only while training, since `niter` is
        # defined in that branch only
        if batch_idx % args.logiter == 0:
            if not validation:
                print(niter, f" Loss: {loss.item():.3f}", flush=True)
                with torch.no_grad():
                    if not args.onlyrot:
                        pred = torch.argmax(final[0][where], dim=1)
                        pseudoloss = XE(final[0][where], pred)
                if not args.onlyrot:
                    self.writer.add_scalar('Pseudoloss', pseudoloss.item(), niter)
                self.writer.add_scalar('lr', self.lr_schedule(epoch), niter)
                self.writer.add_scalar('Loss', loss.item(), niter)
                self.writer.add_scalar('RotLoss', rotloss.item(), niter)
                self.writer.add_scalar('RotAcc', rotacc.item(), niter)
                if batch_idx > 0:
                    self.writer.add_scalar('Freq(Hz)', mass / (time.time() - now), niter)
    # end of epoch logging
    if self.writer and (epoch % self.log_interval == 0):
        write_conv(self.writer, model, epoch)
        if validation:
            print('val Rot-Acc: ', rotacc_value.avg)
            self.writer.add_scalar('val Rot-Acc', rotacc_value.avg, epoch)
    files.save_checkpoint_all(self.checkpoint_dir, model, args.arch,
                              optimizer, self.L, epoch, lowest=False)
    return {'loss': loss_value.avg}