def train(args, model, train_loader, criterion, optimizer, lr_schedule, trlog):
    model.train()
    tl = Averager()
    ta = Averager()
    global global_count, writer
    # NOTE: `epoch` is not a parameter here; it is assumed to be set at module
    # level by the caller's epoch loop, as are `global_count` and `writer`.
    for i, (data, label) in enumerate(train_loader, 1):
        global_count = global_count + 1
        if torch.cuda.is_available():
            data, label = data.cuda(), label.long().cuda()
        elif not args.mixup:
            label = label.long()
        optimizer.zero_grad()
        logits = model(data)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        lr_schedule.step()
        writer.add_scalar('data/loss', float(loss), global_count)
        tl.add(loss.item())
        if not args.mixup:
            acc = count_acc(logits, label)
            ta.add(acc)
            writer.add_scalar('data/acc', float(acc), global_count)
        if (i - 1) % 100 == 0 or i == len(train_loader):
            if not args.mixup:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f} acc={:.4f}'.format(
                    epoch, i, len(train_loader),
                    optimizer.param_groups[0]['lr'], loss.item(), acc))
            else:
                print('epoch {}, train {}/{}, lr={:.5f}, loss={:.4f}'.format(
                    epoch, i, len(train_loader),
                    optimizer.param_groups[0]['lr'], loss.item()))
    if trlog is not None:
        tl = tl.item()
        trlog['train_loss'].append(tl)
        if not args.mixup:
            ta = ta.item()
            trlog['train_acc'].append(ta)
        else:
            trlog['train_acc'].append(0)
        return model, trlog
    else:
        return model
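# --- Hedged sketch: `Averager` and `count_acc` are used throughout this file
# but defined elsewhere. The minimal versions below follow the common
# convention (running mean of scalars; top-1 accuracy) and are assumptions,
# not necessarily the repository's exact implementations.
import torch


class Averager():
    """Running average of scalar values."""

    def __init__(self):
        self.n = 0
        self.v = 0.0

    def add(self, x):
        # incremental mean update
        self.v = (self.v * self.n + x) / (self.n + 1)
        self.n += 1

    def item(self):
        return self.v


def count_acc(logits, label):
    """Top-1 accuracy of `logits` (N x C) against integer labels (N,)."""
    pred = torch.argmax(logits, dim=1)
    return (pred == label).float().mean().item()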
def train_tst(self):
    args = self.args
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()
    # Clear evaluation file
    with open(osp.join(self.args.save_path, 'eval.jl'), 'w') as fp:
        pass
    # start FSL training
    label, label_aux = self.prepare_label()
    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()
        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()
        start_tm = time.time()
        for batch in self.train_loader:
            self.train_step += 1
            if torch.cuda.is_available():
                data, gt_label = [_.cuda() for _ in batch]
            else:
                data, gt_label = batch[0], batch[1]
            data_tm = time.time()
            self.dt.add(data_tm - start_tm)
            # get saved centers
            logits, reg_logits = self.para_model(data)
            loss = F.cross_entropy(logits, label)
            if reg_logits is not None:
                total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
            else:
                total_loss = loss
            tl2.add(loss.item())
            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(logits, label)
            tl1.add(total_loss.item())
            ta.add(acc)
            self.optimizer.zero_grad()
            total_loss.backward()
            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)
            self.optimizer.step()
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)
            # refresh start_tm
            start_tm = time.time()
            if args.debug_fast:
                print('Debug fast, breaking training after 1 mini-batch')
                break
        self.lr_scheduler.step()
        self.try_evaluate_tst(epoch)
        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))
    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
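# --- Hedged sketch: `self.prepare_label()` above is not shown here. For
# episodic FSL training the query labels are typically the relative class
# indices 0..way-1 repeated once per query sample, with `label_aux` covering
# the shot + query samples used by the auxiliary head. The helper below is an
# assumption about that convention, not the repository's exact code.
def prepare_label_sketch(way, shot, query):
    label = torch.arange(way, dtype=torch.long).repeat(query)             # (way * query,)
    label_aux = torch.arange(way, dtype=torch.long).repeat(shot + query)  # (way * (shot + query),)
    if torch.cuda.is_available():
        label, label_aux = label.cuda(), label_aux.cuda()
    return label, label_aux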
def train(self):
    args = self.args
    # start GFSL training
    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()
        start_tm = time.time()
        for _, batch in enumerate(
                zip(self.train_fsl_loader, self.train_gfsl_loader)):
            self.train_step += 1
            if torch.cuda.is_available():
                support_data, support_label = batch[0][0].cuda(), batch[0][1].cuda()
                query_data, query_label = batch[1][0].cuda(), batch[1][1].cuda()
            else:
                support_data, support_label = batch[0][0], batch[0][1]
                query_data, query_label = batch[1][0], batch[1][1]
            query_label_rep = query_label.view(-1, 1).repeat(1, args.num_tasks).view(-1)
            data_tm = time.time()
            self.dt.add(data_tm - start_tm)
            logits = self.model(support_data, query_data, support_label)
            loss = F.cross_entropy(logits, query_label_rep)
            tl2.add(loss.item())
            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(logits, query_label_rep)
            tl1.add(loss.item())
            ta.add(acc)
            self.optimizer.zero_grad()
            loss.backward()
            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)
            self.optimizer.step()
            self.lr_scheduler.step()
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)
            self.try_logging(tl1, tl2, ta)
            # refresh start_tm
            start_tm = time.time()
            del logits, loss
            torch.cuda.empty_cache()
        self.try_evaluate(epoch)
        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))
    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
    self.save_model('epoch-last')
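# --- Hedged sketch: the GFSL loss above tiles every query label `num_tasks`
# times so it lines up with logits stacked task-by-task. The toy function
# below (illustrative names only) shows what that replication produces.
def replicate_labels_example():
    num_tasks = 3
    query_label = torch.tensor([0, 1, 2])
    replicated = query_label.view(-1, 1).repeat(1, num_tasks).view(-1)
    # replicated == tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]);
    # its length must equal logits.size(0) in the cross-entropy call.
    return replicated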
def validate(args, model, val_loader, epoch, trlog=None):
    model.eval()
    global writer
    vl_dist, va_dist, vl_sim, va_sim = Averager(), Averager(), Averager(), Averager()
    if trlog is not None:
        print('[Dist] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_dist_epoch'], trlog['max_acc_dist']))
        print('[Sim] best epoch {}, current best val acc={:.4f}'.format(
            trlog['max_acc_sim_epoch'], trlog['max_acc_sim']))
    # test performance with Few-Shot
    label = torch.arange(args.num_val_class).repeat(args.query).long()
    if torch.cuda.is_available():
        label = label.cuda()
    with torch.no_grad():
        for i, batch in tqdm(enumerate(val_loader, 1), total=len(val_loader)):
            if torch.cuda.is_available():
                data, _ = [_.cuda() for _ in batch]
            else:
                data, _ = batch
            data_shot, data_query = data[:args.num_val_class], data[args.num_val_class:]
            # 16-way test
            logits_dist, logits_sim = model.forward_proto(
                data_shot, data_query, args.num_val_class)
            loss_dist = F.cross_entropy(logits_dist, label)
            acc_dist = count_acc(logits_dist, label)
            loss_sim = F.cross_entropy(logits_sim, label)
            acc_sim = count_acc(logits_sim, label)
            vl_dist.add(loss_dist.item())
            va_dist.add(acc_dist)
            vl_sim.add(loss_sim.item())
            va_sim.add(acc_sim)
    vl_dist = vl_dist.item()
    va_dist = va_dist.item()
    vl_sim = vl_sim.item()
    va_sim = va_sim.item()
    print('epoch {}, val, loss_dist={:.4f} acc_dist={:.4f} loss_sim={:.4f} acc_sim={:.4f}'.format(
        epoch, vl_dist, va_dist, vl_sim, va_sim))
    if trlog is not None:
        writer.add_scalar('data/val_loss_dist', float(vl_dist), epoch)
        writer.add_scalar('data/val_acc_dist', float(va_dist), epoch)
        writer.add_scalar('data/val_loss_sim', float(vl_sim), epoch)
        writer.add_scalar('data/val_acc_sim', float(va_sim), epoch)
        if va_dist > trlog['max_acc_dist']:
            trlog['max_acc_dist'] = va_dist
            trlog['max_acc_dist_epoch'] = epoch
            save_model('max_acc_dist', model, args)
            save_checkpoint(True)
        if va_sim > trlog['max_acc_sim']:
            trlog['max_acc_sim'] = va_sim
            trlog['max_acc_sim_epoch'] = epoch
            save_model('max_acc_sim', model, args)
            save_checkpoint(True)
        trlog['val_loss_dist'].append(vl_dist)
        trlog['val_acc_dist'].append(va_dist)
        trlog['val_loss_sim'].append(vl_sim)
        trlog['val_acc_sim'].append(va_sim)
    return trlog
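# --- Hedged sketch: `model.forward_proto` is defined elsewhere. The two logit
# heads validated above are consistent with a prototype classifier scored by
# (i) negative squared Euclidean distance and (ii) cosine similarity to the
# class prototypes; the helper below only illustrates that idea under those
# assumptions and is not the repository's exact implementation.
def proto_logits_sketch(proto, query, temperature=1.0):
    # proto: (num_class, dim) class prototypes; query: (num_query, dim) embeddings
    logits_dist = -torch.cdist(query, proto).pow(2) / temperature
    logits_sim = F.normalize(query, dim=-1) @ F.normalize(proto, dim=-1).t() / temperature
    return logits_dist, logits_sim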
if torch.cuda.is_available():
    data, label = [_.cuda() for _ in batch]
    label = label.type(torch.cuda.LongTensor)
else:
    data, label = batch
    label = label.type(torch.LongTensor)
logits = model(data)
loss = criterion(logits, label)
acc = count_acc(logits, label)
writer.add_scalar('data/loss', float(loss), global_count)
writer.add_scalar('data/acc', float(acc), global_count)
if (i - 1) % 100 == 0:
    print('epoch {}, train {}/{}, loss={:.4f} acc={:.4f}'.format(
        epoch, i, len(train_loader), loss.item(), acc))
tl.add(loss.item())
ta.add(acc)
optimizer.zero_grad()
loss.backward()
optimizer.step()
tl = tl.item()
ta = ta.item()
# validate every 5 epochs early on, and every epoch once epoch > 100
if epoch > 100 or (epoch - 1) % 5 == 0:
    model.eval()
    vl_dist = Averager()
    va_dist = Averager()
    vl_sim = Averager()
def train(self):
    args = self.args
    self.model.train()
    if self.args.fix_BN:
        self.model.encoder.eval()
    # start FSL training
    label, label_aux = self.prepare_label()
    for epoch in range(1, args.max_epoch + 1):
        self.train_epoch += 1
        self.model.train()
        if self.args.fix_BN:
            self.model.encoder.eval()
        tl1 = Averager()
        tl2 = Averager()
        ta = Averager()
        start_tm = time.time()
        for batch in tqdm(self.train_loader):
            data, gt_label = [_.cuda() for _ in batch]
            data_tm = time.time()
            self.dt.add(data_tm - start_tm)
            # get saved centers
            logits, reg_logits = self.para_model(data)
            logits = logits.view(-1, args.way)
            # open-set margin terms: similarity of each query to its true class
            # versus the learnable margin R (self.model.margin)
            oh_query = torch.nn.functional.one_hot(label, args.way)
            sims = logits
            temp = (sims * oh_query).sum(-1)
            e_sim_p = temp - self.model.margin
            e_sim_p_pos = F.relu(e_sim_p)
            e_sim_p_neg = F.relu(-e_sim_p)
            l_open_margin = e_sim_p_pos.mean(-1)
            l_open = e_sim_p_neg.mean(-1)
            loss = F.cross_entropy(logits, label)
            if reg_logits is not None:
                total_loss = loss + args.balance * F.cross_entropy(reg_logits, label_aux)
            else:
                total_loss = loss
            total_loss = total_loss + args.open_balance * l_open
            tl2.add(loss.item())
            forward_tm = time.time()
            self.ft.add(forward_tm - data_tm)
            acc = count_acc(logits, label)
            tl1.add(total_loss.item())
            ta.add(acc)
            self.optimizer.zero_grad()
            total_loss.backward(retain_graph=True)
            self.optimizer_margin.zero_grad()
            l_open_margin.backward()
            self.optimizer.step()
            self.optimizer_margin.step()
            backward_tm = time.time()
            self.bt.add(backward_tm - forward_tm)
            optimizer_tm = time.time()
            self.ot.add(optimizer_tm - backward_tm)
            # refresh start_tm
            start_tm = time.time()
        print('lr: {:.4f} Total_loss: {:.4f} ce_loss: {:.4f} l_open: {:.4f} R: {:.4f}\n'.format(
            self.optimizer_margin.param_groups[0]['lr'],
            total_loss.item(), loss.item(), l_open.item(), self.model.margin.item()))
        self.lr_scheduler.step()
        self.lr_scheduler_margin.step()
        self.try_evaluate(epoch)
        print('ETA:{}/{}'.format(
            self.timer.measure(),
            self.timer.measure(self.train_epoch / args.max_epoch)))
    torch.save(self.trlog, osp.join(args.save_path, 'trlog'))
    self.save_model('epoch-last')
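# --- Hedged sketch of the margin decomposition in the loop above: the
# similarity of each query to its ground-truth class, s_y, is compared with a
# learnable threshold R (self.model.margin). relu(s_y - R) is averaged into
# `l_open_margin` and backpropagated for the separate `optimizer_margin` step
# that updates R, while relu(R - s_y) is averaged into `l_open` and added
# (scaled by args.open_balance) to the classification loss so that s_y is
# pushed above R. The helper below restates that computation with
# illustrative names only.
def open_margin_terms(sims, label, margin, way):
    oh = torch.nn.functional.one_hot(label, way).float()
    s_pos = (sims * oh).sum(-1)             # similarity to the ground-truth class
    diff = s_pos - margin
    l_open = F.relu(-diff).mean(-1)         # term added to the classification loss
    l_open_margin = F.relu(diff).mean(-1)   # term optimized w.r.t. the margin R
    return l_open, l_open_margin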