def train(args, model, train_loader, criterion, optimizer, lr_schedule, trlog):
    model.train()
    tl = Averager()
    ta = Averager()
    # `epoch` is read from module scope; it is set by the outer training loop.
    global global_count, writer
    for i, (data, label) in enumerate(train_loader, 1):
        global_count = global_count + 1
        if torch.cuda.is_available():
            data, label = data.cuda(), label.long().cuda()
        elif not args.mixup:
            # without mixup, labels must be integer class indices
            label = label.long()
        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        lr_schedule.step()
        writer.add_scalar('data/loss', float(loss), global_count)
        tl.add(loss.item())
        if not args.mixup:
            acc = count_acc(logits, label)
            ta.add(acc)
            writer.add_scalar('data/acc', float(acc), global_count)
        if (i - 1) % 100 == 0 or i == len(train_loader):
            if not args.mixup:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader),
                    optimizer.param_groups[0]['lr'], loss.item(), acc))
            else:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f}'.format(
                    epoch, i, len(train_loader),
                    optimizer.param_groups[0]['lr'], loss.item()))
    if trlog is not None:
        tl = tl.item()
        trlog['train_loss'].append(tl)
        if not args.mixup:
            ta = ta.item()
            trlog['train_acc'].append(ta)
        else:
            trlog['train_acc'].append(0)
        return model, trlog
    else:
        return model
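The helpers `Averager` and `count_acc` are used throughout these snippets but never defined in this section. A minimal sketch of the assumed behavior (a running mean and top-1 accuracy); treat it as an illustration, not the repository's actual utilities:

import torch

class Averager():
    """Running mean; add() accumulates values, item() returns the average."""
    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        # incremental mean update
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v

def count_acc(logits, label):
    """Top-1 accuracy of logits (N x way) against integer labels (N,)."""
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()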
def __init__(self, args):
    self.args = args
    self.logger = Logger(args, osp.join(args.save_path))

    self.train_step = 0
    self.train_epoch = 0
    self.max_steps = args.episodes_per_epoch * args.max_epoch
    self.dt, self.ft = Averager(), Averager()
    self.bt, self.ot = Averager(), Averager()
    self.timer = Timer()

    # train statistics
    self.trlog = {}
    self.trlog['max_auc'] = 0.0
    self.trlog['max_auc_epoch'] = 0
    self.trlog['max_auc_interval'] = 0.0
def __init__(self, args):
    self.args = args
    ensure_path(
        self.args.save_path,
        scripts_to_save=['model/models', 'model/networks', __file__],
    )
    self.logger = Logger(args, osp.join(args.save_path))

    self.train_step = 0
    self.train_epoch = 0
    self.dt, self.ft = Averager(), Averager()
    self.bt, self.ot = Averager(), Averager()
    self.timer = Timer()

    # train statistics
    self.trlog = {}
    self.trlog['max_acc'] = 0.0
    self.trlog['max_acc_epoch'] = 0
    self.trlog['max_acc_interval'] = 0.0
def train_tst(self):
    args = self.args
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()

    # Clear the evaluation file.
    with open(osp.join(self.args.save_path, 'eval.jl'), 'w') as fp:
        pass

    # start FSL training
    label, label_aux = self.prepare_label()
    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()

        start_tm = time.time()
        for batch in self.train_loader:
            self.train_step += 1

            if torch.cuda.is_available():
                data, gt_label = [_.cuda() for _ in batch]
            else:
                data, gt_label = batch[0], batch[1]

            data_tm = time.time()
            self.dt.add(data_tm - start_tm)

            # get saved centers
            logits, reg_logits = self.para_model(data)
            loss = F.cross_entropy(logits, label)
            if reg_logits is not None:
                total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
            else:
                total_loss = loss
            tl2.add(loss.item())

            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(logits, label)

            tl1.add(total_loss.item())
            ta.add(acc)

            self.optimizer.zero_grad()
            total_loss.backward()
            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)

            self.optimizer.step()
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)

            # refresh start_tm
            start_tm = time.time()
            if args.debug_fast:
                print('Debug fast, breaking training after 1 mini-batch')
                break

        self.lr_scheduler.step()
        self.try_evaluate_tst(epoch)

        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))

    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
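`train_tst` relies on `self.prepare_label()` for episode-internal targets. A plausible sketch, assuming FEAT-style episodes where the query targets enumerate the `args.way` classes and the auxiliary targets cover support plus query instances; the actual layout may differ:

def prepare_label(self):
    args = self.args
    # query targets: 0..way-1 repeated for every query sample
    label = torch.arange(args.way, dtype=torch.long).repeat(args.query)
    # auxiliary targets cover support + query instances
    label_aux = torch.arange(args.way, dtype=torch.long).repeat(args.shot + args.query)
    if torch.cuda.is_available():
        label = label.cuda()
        label_aux = label_aux.cuda()
    return label, label_aux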
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            vl, va, vap = self.evaluate(self.val_loader)
            self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
            self.logger.add_scalar('val_acc', float(va), self.train_epoch)
            print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                epoch, vl, va, vap))
            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch, self.train_step, self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec, '
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)
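For reference, a minimal hypothetical subclass illustrating the contract the abstract methods impose; everything beyond the overridden names is invented for illustration:

class DummyTrainer(Trainer):
    # a real subclass would also set self.model, self.optimizer, self.val_loader
    def train(self):
        for epoch in range(1, self.args.max_epoch + 1):
            self.train_epoch += 1
            # ... one epoch of optimization ...
            self.try_evaluate(epoch)

    def evaluate(self, data_loader):
        # expected to return (mean loss, mean acc, 95% confidence interval)
        return 0.0, 0.0, 0.0

    def evaluate_test(self, data_loader):
        return 0.0, 0.0, 0.0

    def final_record(self):
        pass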
def train(self):
    args = self.args
    # start GFSL training
    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()

        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()

        start_tm = time.time()
        for _, batch in enumerate(zip(self.train_fsl_loader, self.train_gfsl_loader)):
            self.train_step += 1
            if torch.cuda.is_available():
                support_data, support_label = batch[0][0].cuda(), batch[0][1].cuda()
                query_data, query_label = batch[1][0].cuda(), batch[1][1].cuda()
            else:
                support_data, support_label = batch[0][0], batch[0][1]
                query_data, query_label = batch[1][0], batch[1][1]

            data_tm = time.time()
            self.dt.add(data_tm - start_tm)

            logits = self.model(support_data, query_data, support_label)
            # repeat each query label num_tasks times so the stacked
            # per-task logits line up with their targets
            loss = F.cross_entropy(
                logits,
                query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1))
            tl2.add(loss.item())

            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(
                logits,
                query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1))

            tl1.add(loss.item())
            ta.add(acc)

            self.optimizer.zero_grad()
            loss.backward()
            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)

            self.optimizer.step()
            self.lr_scheduler.step()
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)

            self.try_logging(tl1, tl2, ta)

            # refresh start_tm
            start_tm = time.time()
            del logits, loss
            torch.cuda.empty_cache()

        self.try_evaluate(epoch)

        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))

    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
    self.save_model('epoch-last')
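The `view(-1, 1).repeat(1, args.num_tasks).view(-1)` idiom repeats each query label `num_tasks` times consecutively; this assumes the model orders the per-task logits of each query contiguously. A quick check of the tensor behavior:

import torch

query_label = torch.tensor([0, 1, 2])  # three queries
num_tasks = 2
tiled = query_label.view(-1, 1).repeat(1, num_tasks).view(-1)
print(tiled)  # tensor([0, 0, 1, 1, 2, 2])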
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        # CUB does not have enough images per class for 50-shot evaluation
        if args.dataset == 'CUB':
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20)]
        else:
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20), (5, 50)]
        if args.eval_dataset == 'CUB':
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20)]
        else:
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20), (5, 50)]
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            if args.eval_all:
                # validate under every (way, shot) setting;
                # model selection tracks only the first one
                for i, (args.eval_way, args.eval_shot) in enumerate(self.VAL_SETTING):
                    if i == 0:
                        vl, va, vap = self.eval_process(args, epoch)
                    else:
                        self.eval_process(args, epoch)
            else:
                vl, va, vap = self.eval_process(args, epoch)
            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')
            print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'], self.trlog['max_acc'],
                self.trlog['max_acc_interval']))

    def eval_process(self, args, epoch):
        valset = self.valset
        if args.model_class in ['QsimProtoNet', 'QsimMatchNet']:
            val_sampler = NegativeSampler(args, valset.label,
                                          args.num_eval_episodes,
                                          args.eval_way,
                                          args.eval_shot + args.eval_query)
        else:
            val_sampler = CategoriesSampler(valset.label,
                                            args.num_eval_episodes,
                                            args.eval_way,
                                            args.eval_shot + args.eval_query)
        val_loader = DataLoader(dataset=valset,
                                batch_sampler=val_sampler,
                                num_workers=args.num_workers,
                                pin_memory=True)
        vl, va, vap = self.evaluate(val_loader)
        self.logger.add_scalar('%dw%ds_val_loss' % (args.eval_way, args.eval_shot),
                               float(vl), self.train_epoch)
        self.logger.add_scalar('%dw%ds_val_acc' % (args.eval_way, args.eval_shot),
                               float(va), self.train_epoch)
        print('epoch {}, {} way {} shot, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
            epoch, args.eval_way, args.eval_shot, vl, va, vap))
        return vl, va, vap

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch, self.train_step, self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec, '
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)
def validate(args, model, val_loader, epoch, trlog=None):
    model.eval()
    global writer
    vl_dist, va_dist, vl_sim, va_sim = Averager(), Averager(), Averager(), Averager()

    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))

    # test performance with Few-Shot
    label = torch.arange(args.num_val_class).repeat(args.query).long()
    if torch.cuda.is_available():
        label = label.cuda()

    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data, _ = batch
            data_shot, data_query = data[:args.num_val_class], data[args.num_val_class:]
            # 16-way test
            logits_dist, logits_sim = model.forward_proto(
                data_shot, data_query, args.num_val_class)
            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)
            vl_dist.add(loss_dist.item())
            va_dist.add(acc_dist)
            vl_sim.add(loss_sim.item())
            va_sim.add(acc_sim)

    vl_dist = vl_dist.item()
    va_dist = va_dist.item()
    vl_sim = vl_sim.item()
    va_sim = va_sim.item()
    print('epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'
          .format(epoch, vl_dist, va_dist, vl_sim, va_sim))

    if trlog is not None:
        writer.add_scalar('data/val_loss_dist', float(vl_dist), epoch)
        writer.add_scalar('data/val_acc_dist', float(va_dist), epoch)
        writer.add_scalar('data/val_loss_sim', float(vl_sim), epoch)
        writer.add_scalar('data/val_acc_sim', float(va_sim), epoch)

        if va_dist > trlog['max_acc_dist']:
            trlog['max_acc_dist'] = va_dist
            trlog['max_acc_dist_epoch'] = epoch
            save_model('max_acc_dist', model, args)
            save_checkpoint(True)
        if va_sim > trlog['max_acc_sim']:
            trlog['max_acc_sim'] = va_sim
            trlog['max_acc_sim_epoch'] = epoch
            save_model('max_acc_sim', model, args)
            save_checkpoint(True)

        trlog['val_loss_dist'].append(vl_dist)
        trlog['val_acc_dist'].append(va_dist)
        trlog['val_loss_sim'].append(vl_sim)
        trlog['val_acc_sim'].append(va_sim)
        return trlog
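`validate` calls a module-level `save_model(name, model, args)` helper, unlike the `Trainer.save_model` method above. A minimal sketch consistent with how the method version serializes weights; the repository's real helper may differ:

import os.path as osp
import torch

def save_model(name, model, args):
    # mirrors Trainer.save_model: wrap the state dict and write <name>.pth
    torch.save(dict(params=model.state_dict()),
               osp.join(args.save_path, name + '.pth'))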
trlog['max_acc_sim'] = 0.0
trlog['max_acc_sim_epoch'] = 0

initial_lr = args.lr
global_count = 0
timer = Timer()
writer = SummaryWriter(logdir=args.save_path)
for epoch in range(init_epoch, args.max_epoch + 1):
    # refine the step-size
    if epoch in args.schedule:
        initial_lr *= args.gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = initial_lr

    model.train()
    tl = Averager()
    ta = Averager()

    for i, batch in enumerate(train_loader, 1):
        global_count = global_count + 1
        if torch.cuda.is_available():
            data, label = [_.cuda() for _ in batch]
            label = label.type(torch.cuda.LongTensor)
        else:
            data, label = batch
            label = label.type(torch.LongTensor)

        logits = model(data)
        loss = criterion(logits, label)
        acc = count_acc(logits, label)

        writer.add_scalar('data/loss', float(loss), global_count)
        writer.add_scalar('data/acc', float(acc), global_count)
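The manual decay above (multiply the learning rate by `args.gamma` at every epoch listed in `args.schedule`) matches PyTorch's built-in `MultiStepLR`; an equivalent sketch, assuming training starts from epoch 1 rather than a resumed `init_epoch`:

from torch.optim.lr_scheduler import MultiStepLR

# decays the optimizer's lr by gamma each time a milestone epoch is reached
scheduler = MultiStepLR(optimizer, milestones=args.schedule, gamma=args.gamma)
for epoch in range(1, args.max_epoch + 1):
    # ... train one epoch ...
    scheduler.step()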
                    pin_memory=True)  # tail of the preceding DataLoader(...) call
# test_set = Dataset('test', args.unsupervised, args)
test_set = get_dataset(n, 'test', args.unsupervised, args)
sampler = CategoriesSampler(test_set.label, args.num_test_episodes,
                            args.way, args.shot + args.query)
loader = DataLoader(dataset=test_set,
                    batch_sampler=sampler,
                    num_workers=8,
                    pin_memory=True)

shot_label = torch.arange(min(args.way, valset.num_class)).repeat(args.shot).numpy()
query_label = torch.arange(min(args.way, valset.num_class)).repeat(args.query).numpy()

val_acc_record = np.zeros((500, len(c_list)))
ave_acc = Averager()
with torch.no_grad():
    for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader), desc='val eval'):
        if torch.cuda.is_available():
            data, _ = [_.cuda() for _ in batch]
        else:
            data = batch[0]
        data_emb = model(data, is_emb=True)
        if args.centralize:
            data_emb = data_emb - class_mean
        if args.normalize:
            data_emb = F.normalize(data_emb, dim=1, p=2)
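`CategoriesSampler` is the episodic batch sampler used across these snippets: each batch draws `n_cls` classes and `n_per` indices per class. A minimal sketch of the assumed behavior, not necessarily the repository's implementation:

import numpy as np
import torch

class CategoriesSampler():
    """Yields n_batch episodes; each picks n_cls classes and n_per indices per class."""
    def __init__(self, label, n_batch, n_cls, n_per):
        self.n_batch = n_batch
        self.n_cls = n_cls
        self.n_per = n_per
        label = np.array(label)
        # indices of every example, grouped by class
        self.m_ind = [torch.from_numpy(np.argwhere(label == i).reshape(-1))
                      for i in range(int(label.max()) + 1)]

    def __len__(self):
        return self.n_batch

    def __iter__(self):
        for _ in range(self.n_batch):
            classes = torch.randperm(len(self.m_ind))[:self.n_cls]
            batch = [self.m_ind[c][torch.randperm(len(self.m_ind[c]))[:self.n_per]]
                     for c in classes]
            # transpose so one sample of every class comes first, then the next, etc.
            yield torch.stack(batch).t().reshape(-1)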
class Trainer(object, metaclass=abc.ABCMeta):
    def __init__(self, args):
        self.args = args
        # ensure_path(
        #     self.args.save_path,
        #     scripts_to_save=['model/models', 'model/networks', __file__],
        # )
        self.logger = Logger(args, osp.join(args.save_path))

        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()

        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

        # For tst
        if args.tst_free:
            self.trlog['max_tst_criterion'] = 0.0
            self.trlog['max_tst_criterion_interval'] = 0.0
            self.trlog['max_tst_criterion_epoch'] = 0
            self.trlog['tst_criterion'] = args.tst_criterion

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def print_metric_summaries(self, metric_summaries, prefix='\t'):
        for key, (mean, std) in metric_summaries.items():
            print('{}{}: {:.4f} +/- {:.4f}'.format(prefix, key, mean, std))

    def log_metric_summaries(self, metric_summaries, epoch, prefix=''):
        for key, (mean, std) in metric_summaries.items():
            self.logger.add_scalar('{}{}'.format(prefix, key), mean, epoch)

    def try_evaluate(self, epoch):
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            if not args.tst_free:
                vl, va, vap = self.evaluate(self.val_loader)
                self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
                self.logger.add_scalar('val_acc', float(va), self.train_epoch)
                print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                    epoch, vl, va, vap))
            else:
                vl, va, vap, metrics = self.evaluate(self.val_loader)
                self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
                self.logger.add_scalar('val_acc', float(va), self.train_epoch)
                print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
                    epoch, vl, va, vap))
                self.print_metric_summaries(metrics, prefix='\tval_')
                self.log_metric_summaries(metrics, epoch=epoch, prefix='val_')

            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')

            # TST may use a different model-selection criterion -> optimize here.
            if args.tst_free and args.tst_criterion:
                assert args.tst_criterion in metrics, 'Criterion {} not found in {}'.format(
                    args.tst_criterion, metrics.keys())
                criterion, criterion_interval = metrics[args.tst_criterion]
                if criterion >= self.trlog['max_tst_criterion']:
                    self.trlog['max_tst_criterion'] = criterion
                    self.trlog['max_tst_criterion_interval'] = criterion_interval
                    self.trlog['max_tst_criterion_epoch'] = self.train_epoch
                    self.save_model('max_tst_criterion')
                    print('Found new best model at Epoch {} : Validation {} = {:.4f} +/- {:.4f}'
                          .format(self.train_epoch, args.tst_criterion,
                                  criterion, criterion_interval))

    def try_logging(self, tl1, tl2, ta, tg=None):
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch, self.train_step, self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec, '
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)
def train(self):
    args = self.args
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()

    # start FSL training
    label, label_aux = self.prepare_label()

    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()

        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()

        start_tm = time.time()
        for batch in tqdm(self.train_loader):
            data, gt_label = [_.cuda() for _ in batch]

            data_tm = time.time()
            self.dt.add(data_tm - start_tm)

            # get saved centers
            logits, reg_logits = self.para_model(data)
            logits = logits.view(-1, args.way)

            # margin terms: true-class similarity vs. a learned margin
            oh_query = torch.nn.functional.one_hot(label, args.way)
            sims = logits
            temp = (sims * oh_query).sum(-1)
            e_sim_p = temp - self.model.margin
            e_sim_p_pos = F.relu(e_sim_p)
            e_sim_p_neg = F.relu(-e_sim_p)
            l_open_margin = e_sim_p_pos.mean(-1)
            l_open = e_sim_p_neg.mean(-1)

            loss = F.cross_entropy(logits, label)
            if reg_logits is not None:
                total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
            else:
                total_loss = loss  # previously left undefined on this branch
            total_loss = total_loss + args.open_balance * l_open
            tl2.add(loss.item())

            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(logits, label)

            tl1.add(total_loss.item())
            ta.add(acc)

            self.optimizer.zero_grad()
            total_loss.backward(retain_graph=True)
            self.optimizer_margin.zero_grad()
            l_open_margin.backward()
            self.optimizer.step()
            self.optimizer_margin.step()

            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)

            # refresh start_tm
            start_tm = time.time()

        print('lr: {:.4f} Total_loss: {:.4f} ce_loss: {:.4f} l_open: {:.4f} R: {:.4f}\n'.format(
            self.optimizer_margin.param_groups[0]['lr'],
            total_loss.item(), loss.item(), l_open.item(), self.model.margin.item()))

        self.lr_scheduler.step()
        self.lr_scheduler_margin.step()
        self.try_evaluate(epoch)

        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))

    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
    self.save_model('epoch-last')
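To make the two margin terms concrete: `e_sim_p` is the true-class similarity minus the learned margin. Its positive part (`l_open_margin`) is stepped with the separate `optimizer_margin`, pulling the margin up toward confident similarities, while its negative part (`l_open`) enters `total_loss` and penalizes similarities that fall below the margin. A tiny numeric check of how the hinge splits:

import torch
import torch.nn.functional as F

sims_true = torch.tensor([2.0, 0.5, 1.2])  # true-class similarities of 3 queries
margin = torch.tensor(1.0)
e_sim_p = sims_true - margin               # tensor([ 1.0, -0.5,  0.2])
l_open_margin = F.relu(e_sim_p).mean()     # (1.0 + 0 + 0.2) / 3 = 0.4
l_open = F.relu(-e_sim_p).mean()           # (0 + 0.5 + 0) / 3 ~= 0.1667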