def train(args, model, train_loader, criterion, optimizer, lr_schedule, trlog):
    """Run one training epoch of ``model`` over ``train_loader``.

    Every batch performs forward / ``criterion`` / backward, then
    ``optimizer.step()`` followed by ``lr_schedule.step()``. So the learning
    rate schedule advances once per iteration, not once per epoch. The loss is
    written to the module-global ``writer`` under ``data/loss`` at the
    module-global step counter ``global_count``. When ``args.mixup`` is off,
    batch accuracy (``count_acc``) is also tracked and logged under
    ``data/acc``.

    NOTE(review): ``epoch`` (used in the progress messages) is neither a
    parameter nor a local; it is resolved as a module-level global. Confirm
    that the caller sets it before calling ``train``.

    Args:
        args: namespace; only ``args.mixup`` is read here.
        model: the network; put into training mode and called on each batch.
        train_loader: iterable of ``(data, label)`` batches with ``len()``.
        criterion: loss callable ``criterion(logits, label)``.
        optimizer: optimizer; ``param_groups[0]['lr']`` is printed.
        lr_schedule: scheduler stepped after every optimizer step.
        trlog: optional dict with ``'train_loss'`` and ``'train_acc'`` lists
            that receive this epoch's averages (0 for accuracy under mixup).

    Returns:
        ``(model, trlog)`` when ``trlog`` is not ``None``, otherwise ``model``.
    """
    model.train()
    tl = Averager()
    ta = Averager()
    global global_count, writer
    num_batches = len(train_loader)
    for i, (data, label) in enumerate(train_loader, 1):
        global_count = global_count + 1
        # Hard labels are cast to int64 for the loss. Mixup labels are left
        # untouched on every device. Previously the CUDA path cast them
        # unconditionally while the CPU path did not.
        if not args.mixup:
            label = label.long()
        if torch.cuda.is_available():
            data, label = data.cuda(), label.cuda()

        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        lr_schedule.step()

        writer.add_scalar('data/loss', float(loss), global_count)
        tl.add(loss.item())
        if not args.mixup:
            # Accuracy is only meaningful against hard labels.
            acc = count_acc(logits, label)
            ta.add(acc)
            writer.add_scalar('data/acc', float(acc), global_count)

        # Report on the first batch, every 100th batch after it, and the last.
        if (i - 1) % 100 == 0 or i == num_batches:
            lr = optimizer.param_groups[0]['lr']
            if not args.mixup:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, num_batches, lr, loss.item(), acc))
            else:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f}'.format(
                    epoch, i, num_batches, lr, loss.item()))

    if trlog is not None:
        trlog['train_loss'].append(tl.item())
        if not args.mixup:
            trlog['train_acc'].append(ta.item())
        else:
            # Accuracy is not tracked under mixup; keep the list aligned.
            trlog['train_acc'].append(0)
        return model, trlog
    else:
        return model
class Trainer(object, metaclass=abc.ABCMeta):
    """Abstract base for episodic trainers.

    Owns the run-wide bookkeeping:
    - the ``Logger``;
    - step and epoch counters;
    - four timing averagers (data / forward / backward / optim);
    - the best-validation record in ``self.trlog``.

    It also provides shared evaluation, logging and checkpoint helpers.
    Subclasses implement the four abstract methods. The helpers also read
    ``self.model``, ``self.optimizer`` and ``self.val_loader``, which this
    class never sets. Presumably the subclass provides them; TODO confirm.
    """

    def __init__(self, args):
        self.args = args
        # NOTE: the save directory is not created here (an ensure_path call
        # was disabled in the original code).
        self.logger = Logger(args, osp.join(args.save_path))
        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        # Timing averagers: data loading, forward, backward, optimizer step.
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()
        # Best validation accuracy seen so far, its interval and its epoch.
        self.trlog = {
            'max_acc': 0.0,
            'max_acc_epoch': 0,
            'max_acc_interval': 0.0,
        }

    @abc.abstractmethod
    def train(self):
        """Run the full training loop."""

    @abc.abstractmethod
    def evaluate(self, data_loader):
        """Return ``(loss, acc, acc_interval)`` on ``data_loader``."""

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        """Evaluate on the test split."""

    @abc.abstractmethod
    def final_record(self):
        """Persist the final results."""

    def try_evaluate(self, epoch):
        """Validate when ``train_epoch`` is a multiple of ``args.eval_interval``.

        Logs ``val_loss`` / ``val_acc`` and prints a summary. When the
        accuracy is at least the best so far (ties count as new bests), it
        updates ``self.trlog`` and saves the weights as ``max_acc``.
        """
        args = self.args
        if self.train_epoch % args.eval_interval != 0:
            return
        loss, acc, interval = self.evaluate(self.val_loader)
        self.logger.add_scalar('val_loss', float(loss), self.train_epoch)
        self.logger.add_scalar('val_acc', float(acc), self.train_epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
            epoch, loss, acc, interval))
        if acc >= self.trlog['max_acc']:
            self.trlog.update(
                max_acc=acc,
                max_acc_interval=interval,
                max_acc_epoch=self.train_epoch,
            )
            self.save_model('max_acc')

    def try_logging(self, tl1, tl2, ta, tg=None):
        """Every ``args.log_interval`` steps, print and log training stats.

        ``tl1``, ``tl2`` and ``ta`` (and ``tg`` when given) must expose
        ``.item()``. They are logged as total loss, loss, accuracy and
        gradient norm. The timing averagers are printed, then the logger is
        dumped.
        """
        args = self.args
        step = self.train_step
        if step % args.log_interval != 0:
            return
        total_loss, loss, acc = tl1.item(), tl2.item(), ta.item()
        lr = self.optimizer.param_groups[0]['lr']
        print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
              .format(self.train_epoch, step, self.max_steps,
                      total_loss, loss, acc, lr))
        self.logger.add_scalar('train_total_loss', total_loss, step)
        self.logger.add_scalar('train_loss', loss, step)
        self.logger.add_scalar('train_acc', acc, step)
        if tg is not None:
            self.logger.add_scalar('grad_norm', tg.item(), step)
        timings = (self.dt.item(), self.ft.item(),
                   self.bt.item(), self.ot.item())
        print('data_timer: {:.2f} sec, '
              'forward_timer: {:.2f} sec,'
              'backward_timer: {:.2f} sec, '
              'optim_timer: {:.2f} sec'.format(*timings))
        self.logger.dump()

    def save_model(self, name):
        """Save ``{'params': model.state_dict()}`` to ``<save_path>/<name>.pth``."""
        path = osp.join(self.args.save_path, name + '.pth')
        torch.save({'params': self.model.state_dict()}, path)

    def __str__(self):
        trainer_name = self.__class__.__name__
        model_name = self.model.__class__.__name__
        return '%s(%s)' % (trainer_name, model_name)
class Trainer(object, metaclass=abc.ABCMeta):
    """Abstract base for episodic trainers that validate on several N-way K-shot settings.

    ``VAL_SETTING`` and ``TEST_SETTINGS`` list ``(way, shot)`` pairs. When the
    (eval) dataset is ``'CUB'`` the 50-shot setting is omitted. With
    ``args.eval_all`` every validation setting is run each evaluation epoch.
    Model selection uses the first setting's accuracy.

    The helpers read ``self.model``, ``self.optimizer`` and ``self.valset``,
    which this class never sets. Presumably the subclass provides them;
    TODO confirm.
    """

    def __init__(self, args):
        if args.dataset == 'CUB':
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20)]
        else:
            self.VAL_SETTING = [(5, 1), (5, 5), (5, 20), (5, 50)]
        if args.eval_dataset == 'CUB':
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20)]
        else:
            self.TEST_SETTINGS = [(5, 1), (5, 5), (5, 20), (5, 50)]
        self.args = args
        # NOTE: the save directory is not created here (an ensure_path call
        # was disabled in the original code).
        self.logger = Logger(args, osp.join(args.save_path))
        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        # Timing averagers: data loading, forward, backward, optimizer step.
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()
        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def try_evaluate(self, epoch):
        """Validate when ``train_epoch`` is a multiple of ``args.eval_interval``.

        With ``args.eval_all`` every ``VAL_SETTING`` is evaluated and the first
        one drives model selection. Otherwise only the current
        ``args.eval_way`` / ``args.eval_shot`` is evaluated. When the selected
        accuracy is at least the best so far, the weights are saved as
        ``max_acc``.
        """
        args = self.args
        if self.train_epoch % args.eval_interval == 0:
            if args.eval_all:
                vl, va, vap = self._eval_all_settings(args, epoch)
            else:
                vl, va, vap = self.eval_process(args, epoch)
            if va >= self.trlog['max_acc']:
                self.trlog['max_acc'] = va
                self.trlog['max_acc_interval'] = vap
                self.trlog['max_acc_epoch'] = self.train_epoch
                self.save_model('max_acc')
            print('best epoch {}, best val acc={:.4f} + {:.4f}'.format(
                self.trlog['max_acc_epoch'], self.trlog['max_acc'],
                self.trlog['max_acc_interval']))

    def _eval_all_settings(self, args, epoch):
        """Run ``eval_process`` for each ``VAL_SETTING``; return the first result.

        ``eval_process`` reads ``args.eval_way`` / ``args.eval_shot``, so they
        are overwritten per setting. They are restored afterwards so the shared
        ``args`` is not left at the last validation setting.
        """
        saved_setting = (args.eval_way, args.eval_shot)
        first_result = None
        try:
            for way, shot in self.VAL_SETTING:
                args.eval_way, args.eval_shot = way, shot
                result = self.eval_process(args, epoch)
                if first_result is None:
                    first_result = result
        finally:
            args.eval_way, args.eval_shot = saved_setting
        return first_result

    def eval_process(self, args, epoch):
        """Evaluate on ``self.valset`` at ``args.eval_way`` / ``args.eval_shot``.

        Uses ``NegativeSampler`` for the Qsim model classes and
        ``CategoriesSampler`` otherwise. Logs loss and accuracy under a
        ``<way>w<shot>s_`` prefix and returns ``(loss, acc, acc_interval)``.
        """
        valset = self.valset
        if args.model_class in ['QsimProtoNet', 'QsimMatchNet']:
            val_sampler = NegativeSampler(args, valset.label,
                                          args.num_eval_episodes,
                                          args.eval_way,
                                          args.eval_shot + args.eval_query)
        else:
            val_sampler = CategoriesSampler(valset.label,
                                            args.num_eval_episodes,
                                            args.eval_way,
                                            args.eval_shot + args.eval_query)
        val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                                num_workers=args.num_workers, pin_memory=True)
        vl, va, vap = self.evaluate(val_loader)
        self.logger.add_scalar('%dw%ds_val_loss' % (args.eval_way, args.eval_shot),
                               float(vl), self.train_epoch)
        self.logger.add_scalar('%dw%ds_val_acc' % (args.eval_way, args.eval_shot),
                               float(va), self.train_epoch)
        print('epoch {},{} way {} shot, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
            epoch, args.eval_way, args.eval_shot, vl, va, vap))
        return vl, va, vap

    def try_logging(self, tl1, tl2, ta, tg=None):
        """Every ``args.log_interval`` steps, print and log training stats and timings."""
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch, self.train_step, self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec,'
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        """Save ``{'params': model.state_dict()}`` to ``<save_path>/<name>.pth``."""
        torch.save(
            dict(params=self.model.state_dict()),
            osp.join(self.args.save_path, name + '.pth')
        )

    def __str__(self):
        return "{}({})".format(
            self.__class__.__name__,
            self.model.__class__.__name__
        )
def validate(args, model, val_loader, epoch, trlog=None):
    """Evaluate ``model`` on ``val_loader`` with both of its prototype heads.

    Each batch is split into two parts. The first ``args.num_val_class``
    samples are the shots and the rest are the queries. ``model.forward_proto``
    returns two logit tensors (the "dist" and "sim" heads). Each head is scored
    with ``F.cross_entropy`` and ``count_acc`` against labels
    ``0..num_val_class-1`` tiled ``args.query`` times.

    When ``trlog`` is given:
    - the current bests are printed first;
    - the four averages are written to the global ``writer``;
    - each head whose accuracy strictly improves has its best accuracy and
      epoch updated, then ``save_model`` and ``save_checkpoint(True)`` are
      called;
    - the averages are appended to ``trlog``'s history lists.

    Returns:
        ``trlog`` (``None`` when it was not supplied).
    """
    model.eval()
    global writer
    use_cuda = torch.cuda.is_available()
    n_way = args.num_val_class

    loss_dist_avg, acc_dist_avg = Averager(), Averager()
    loss_sim_avg, acc_sim_avg = Averager(), Averager()

    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))

    # Query labels: 0..n_way-1 repeated args.query times.
    label = torch.arange(n_way).repeat(args.query).long()
    if use_cuda:
        label = label.cuda()

    with torch.no_grad():
        for batch in tqdm(val_loader, total=len(val_loader)):
            if use_cuda:
                data, _ = [tensor.cuda() for tensor in batch]
            else:
                data, _ = batch
            support, query = data[:n_way], data[n_way:]
            logits_dist, logits_sim = model.forward_proto(support, query, n_way)

            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)

            loss_dist_avg.add(loss_dist.item())
            acc_dist_avg.add(acc_dist)
            loss_sim_avg.add(loss_sim.item())
            acc_sim_avg.add(acc_sim)

    vl_dist = loss_dist_avg.item()
    va_dist = acc_dist_avg.item()
    vl_sim = loss_sim_avg.item()
    va_sim = acc_sim_avg.item()
    print('epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'
          .format(epoch, vl_dist, va_dist, vl_sim, va_sim))

    if trlog is None:
        return trlog

    scalars = (
        ('data/val_loss_dist', vl_dist),
        ('data/val_acc_dist', va_dist),
        ('data/val_loss_sim', vl_sim),
        ('data/val_acc_sim', va_sim),
    )
    for tag, value in scalars:
        writer.add_scalar(tag, float(value), epoch)

    if va_dist > trlog['max_acc_dist']:
        trlog['max_acc_dist'] = va_dist
        trlog['max_acc_dist_epoch'] = epoch
        save_model('max_acc_dist', model, args)
        save_checkpoint(True)
    if va_sim > trlog['max_acc_sim']:
        trlog['max_acc_sim'] = va_sim
        trlog['max_acc_sim_epoch'] = epoch
        save_model('max_acc_sim', model, args)
        save_checkpoint(True)

    history = (
        ('val_loss_dist', vl_dist),
        ('val_acc_dist', va_dist),
        ('val_loss_sim', vl_sim),
        ('val_acc_sim', va_sim),
    )
    for key, value in history:
        trlog[key].append(value)
    return trlog
loss = criterion(logits, label) acc = count_acc(logits, label) writer.add_scalar('data/loss', float(loss), global_count) writer.add_scalar('data/acc', float(acc), global_count) if (i - 1) % 100 == 0: print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format( epoch, i, len(train_loader), loss.item(), acc)) tl.add(loss.item()) ta.add(acc) optimizer.zero_grad() loss.backward() optimizer.step() tl = tl.item() ta = ta.item() # do not do validation in first 500 epoches if epoch > 100 or (epoch - 1) % 5 == 0: model.eval() vl_dist = Averager() va_dist = Averager() vl_sim = Averager() va_sim = Averager() print('[Dist] best epoch {}, current best val acc={:.4f}'.format( trlog['max_acc_dist_epoch'], trlog['max_acc_dist'])) print('[Sim] best epoch {}, current best val acc={:.4f}'.format( trlog['max_acc_sim_epoch'], trlog['max_acc_sim'])) # test performance with Few-Shot label = torch.arange(valset.num_class).repeat(args.query)
class Trainer(object, metaclass=abc.ABCMeta):
    """Abstract base for episodic trainers with optional TST-free validation.

    When ``args.tst_free`` is set:
    - ``evaluate`` is expected to return a fourth value, a dict mapping
      metric name to ``(mean, std)``;
    - those metrics are printed and logged;
    - if ``args.tst_criterion`` names one of them, a second checkpoint
      (``max_tst_criterion``) tracks the best value of that metric.

    The helpers read ``self.model``, ``self.optimizer`` and ``self.val_loader``,
    which this class never sets. Presumably the subclass provides them;
    TODO confirm.
    """

    def __init__(self, args):
        self.args = args
        # NOTE: the save directory is not created here (an ensure_path call
        # was disabled in the original code).
        self.logger = Logger(args, osp.join(args.save_path))
        self.train_step = 0
        self.train_epoch = 0
        self.max_steps = args.episodes_per_epoch * args.max_epoch
        # Timing averagers: data loading, forward, backward, optimizer step.
        self.dt, self.ft = Averager(), Averager()
        self.bt, self.ot = Averager(), Averager()
        self.timer = Timer()
        # train statistics
        self.trlog = {}
        self.trlog['max_acc'] = 0.0
        self.trlog['max_acc_epoch'] = 0
        self.trlog['max_acc_interval'] = 0.0
        # Best value of the TST selection criterion (only in TST-free mode).
        if args.tst_free:
            self.trlog['max_tst_criterion'] = 0.0
            self.trlog['max_tst_criterion_interval'] = 0.
            self.trlog['max_tst_criterion_epoch'] = 0
            self.trlog['tst_criterion'] = args.tst_criterion

    @abc.abstractmethod
    def train(self):
        pass

    @abc.abstractmethod
    def evaluate(self, data_loader):
        pass

    @abc.abstractmethod
    def evaluate_test(self, data_loader):
        pass

    @abc.abstractmethod
    def final_record(self):
        pass

    def print_metric_summaries(self, metric_summaries, prefix='\t'):
        """Print each ``name: mean +/- std`` entry of ``metric_summaries``."""
        for key, (mean, std) in metric_summaries.items():
            print('{}{}: {:.4f} +/- {:.4f}'.format(prefix, key, mean, std))

    def log_metric_summaries(self, metric_summaries, epoch, prefix=''):
        """Log the mean of each entry of ``metric_summaries`` at ``epoch``."""
        for key, (mean, std) in metric_summaries.items():
            self.logger.add_scalar('{}{}'.format(prefix, key), mean, epoch)

    def try_evaluate(self, epoch):
        """Validate when ``train_epoch`` is a multiple of ``args.eval_interval``.

        Saves ``max_acc`` when the accuracy is at least the best so far. In
        TST-free mode with a configured ``args.tst_criterion``, it also saves
        ``max_tst_criterion`` when that metric is at least its best so far.

        Raises:
            KeyError: ``args.tst_criterion`` is not among the returned metrics.
        """
        args = self.args
        if self.train_epoch % args.eval_interval != 0:
            return

        # In TST-free mode evaluate() also returns a metric-summary dict.
        result = self.evaluate(self.val_loader)
        if args.tst_free:
            vl, va, vap, metrics = result
        else:
            vl, va, vap = result
        self.logger.add_scalar('val_loss', float(vl), self.train_epoch)
        self.logger.add_scalar('val_acc', float(va), self.train_epoch)
        print('epoch {}, val, loss={:.4f} acc={:.4f}+{:.4f}'.format(
            epoch, vl, va, vap))
        if args.tst_free:
            self.print_metric_summaries(metrics, prefix='\tval_')
            self.log_metric_summaries(metrics, epoch=epoch, prefix='val_')

        if va >= self.trlog['max_acc']:
            self.trlog['max_acc'] = va
            self.trlog['max_acc_interval'] = vap
            self.trlog['max_acc_epoch'] = self.train_epoch
            self.save_model('max_acc')

        # TST may select on a different criterion -> track it separately.
        if args.tst_free and args.tst_criterion:
            # Explicit check instead of `assert`, which `python -O` strips.
            if args.tst_criterion not in metrics:
                raise KeyError('Criterion {} not found in {}'.format(
                    args.tst_criterion, metrics.keys()))
            criterion, criterion_interval = metrics[args.tst_criterion]
            if criterion >= self.trlog['max_tst_criterion']:
                self.trlog['max_tst_criterion'] = criterion
                self.trlog['max_tst_criterion_interval'] = criterion_interval
                self.trlog['max_tst_criterion_epoch'] = self.train_epoch
                self.save_model('max_tst_criterion')
                print('Found new best model at Epoch {} : Validation {} = {:.4f} +/- {:.4f}'
                      .format(self.train_epoch, args.tst_criterion,
                              criterion, criterion_interval))

    def try_logging(self, tl1, tl2, ta, tg=None):
        """Every ``args.log_interval`` steps, print and log training stats and timings."""
        args = self.args
        if self.train_step % args.log_interval == 0:
            print('epoch {}, train {:06g}/{:06g}, total loss={:.4f}, loss={:.4f} acc={:.4f}, lr={:.4g}'
                  .format(self.train_epoch, self.train_step, self.max_steps,
                          tl1.item(), tl2.item(), ta.item(),
                          self.optimizer.param_groups[0]['lr']))
            self.logger.add_scalar('train_total_loss', tl1.item(), self.train_step)
            self.logger.add_scalar('train_loss', tl2.item(), self.train_step)
            self.logger.add_scalar('train_acc', ta.item(), self.train_step)
            if tg is not None:
                self.logger.add_scalar('grad_norm', tg.item(), self.train_step)
            print('data_timer: {:.2f} sec, '
                  'forward_timer: {:.2f} sec,'
                  'backward_timer: {:.2f} sec, '
                  'optim_timer: {:.2f} sec'.format(
                      self.dt.item(), self.ft.item(),
                      self.bt.item(), self.ot.item()))
            self.logger.dump()

    def save_model(self, name):
        """Save ``{'params': model.state_dict()}`` to ``<save_path>/<name>.pth``."""
        torch.save(dict(params=self.model.state_dict()),
                   osp.join(self.args.save_path, name + '.pth'))

    def __str__(self):
        return "{}({})".format(self.__class__.__name__,
                               self.model.__class__.__name__)