# Shared imports for the train-script variants below. The project-local modules
# (Engine, data_utils, model_utils, log_utils) are assumed to follow the
# protonets package layout used by the original prototypical-networks code.
import os
import json
from functools import partial

import numpy as np
import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torchnet as tnt

from protonets.engine import Engine
import protonets.utils.data as data_utils
import protonets.utils.model as model_utils
import protonets.utils.log as log_utils


# Variant 1: baseline training loop; best-model checkpointing / early stopping is
# intentionally disabled ('and False' below) and the model is saved after every epoch.
def main(opt):
    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    np.random.seed(4321)
    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    if opt['data.trainval']:
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None
    else:
        data = data_utils.load(opt, ['train', 'val'])
        train_loader = data['train']
        val_loader = data['val']

    model = model_utils.load(opt)

    if opt['data.cuda']:
        model.cuda()

    engine = Engine()

    meters = {'train': {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}}

    if val_loader is not None:
        meters['val'] = {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}

    def on_start(state):
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        # gamma=1 keeps the learning rate constant even though a step schedule is attached
        state['scheduler'] = lr_scheduler.StepLR(state['optimizer'], opt['train.decay_every'], gamma=1)
    engine.hooks['on_start'] = on_start

    def on_start_epoch(state):
        for split, split_meters in meters.items():
            for field, meter in split_meters.items():
                meter.reset()
        state['scheduler'].step()
    engine.hooks['on_start_epoch'] = on_start_epoch

    def on_update(state):
        for field, meter in meters['train'].items():
            meter.add(state['output'][field])
    engine.hooks['on_update'] = on_update

    def on_end_epoch(hook_state, state):
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = np.inf
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

        if val_loader is not None:
            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is not None and False:  # 'and False' intentionally disables this checkpointing / early-stopping block
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print("==> best model (loss = {:0.6f}), saving model...".format(hook_state['best_loss']))

                state['model'].cpu()
                torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(opt['train.patience']))
                    state['stop'] = True
        else:
            # save the model unconditionally at the end of every epoch
            state['model'].cpu()
            torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            if opt['data.cuda']:
                state['model'].cuda()

    engine.hooks['on_end_epoch'] = partial(on_end_epoch, {})

    engine.train(
        model=model,
        loader=train_loader,
        optim_method=getattr(optim, opt['train.optim_method']),
        optim_config={'lr': opt['train.learning_rate'],
                      'weight_decay': opt['train.weight_decay']},
        max_epoch=opt['train.epochs']
    )
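# Usage sketch for the variant above: main() expects a flat option dict keyed as
# referenced in the function body. The keys below are the ones that body reads;
# every value is an illustrative placeholder, and data_utils.load / model_utils.load
# will consume further data.* / model.* keys that are not listed here.
example_opt = {
    'log.exp_dir': 'results/example_run',  # where opt.json, trace.txt and checkpoints are written
    'log.fields': 'loss,acc',              # comma-separated; split into a list inside main()
    'model.x_dim': '1,28,28',              # comma-separated input shape; converted to ints inside main()
    'data.cuda': False,
    'data.trainval': False,                # False -> separate train and val loaders
    'train.optim_method': 'Adam',          # looked up with getattr(optim, ...)
    'train.learning_rate': 1e-3,
    'train.weight_decay': 0.0,
    'train.decay_every': 20,               # StepLR step size, in epochs
    'train.patience': 200,                 # epochs without improvement before early stopping
    'train.epochs': 100,
}
# main(example_opt)  # would launch training once the data/model utilities are available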
# Variant 2: episodic CUB birds loader. PytorchBirdsDataLoader is a project-local
# class not defined in this file (a minimal interface sketch follows this variant);
# a single loader serves both training and validation via its `mode` flag.
def main(opt):
    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    # Original split-based loading, replaced by the episodic birds loader below:
    # if opt['data.trainval']:
    #     data = data_utils.load(opt, ['trainval'])
    #     train_loader = data['trainval']
    #     val_loader = None
    # else:
    #     data = data_utils.load(opt, ['train', 'val'])
    #     train_loader = data['train']
    #     val_loader = data['val']

    # added: non-None placeholder so the validation branches below still run;
    # validation actually reuses state['loader'] with its mode switched to 'val'
    val_loader = 'HACK'

    train_way = opt['data.way']
    train_shot = opt['data.shot']
    train_query = opt['data.query']
    train_episodes = opt['data.train_episodes']
    data = PytorchBirdsDataLoader(n_episodes=train_episodes,
                                  n_way=train_way,
                                  n_query=train_query,
                                  n_support=train_shot,
                                  image_dim=(224, 224, 3))

    model = model_utils.load(opt)

    if opt['data.cuda']:
        model.cuda()
        # model.encoder.cuda()

    engine = Engine()

    meters = {'train': {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}}

    if val_loader is not None:
        meters['val'] = {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}

    def on_start(state):
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        state['scheduler'] = lr_scheduler.StepLR(state['optimizer'], opt['train.decay_every'], gamma=0.5)
    engine.hooks['on_start'] = on_start

    def on_start_epoch(state):
        state['loader'].mode = 'train'  # added: switch the shared loader back to training episodes
        for split, split_meters in meters.items():
            for field, meter in split_meters.items():
                meter.reset()
        state['scheduler'].step()
    engine.hooks['on_start_epoch'] = on_start_epoch

    def on_update(state):
        for field, meter in meters['train'].items():
            meter.add(state['output'][field])
    engine.hooks['on_update'] = on_update

    def on_end_epoch(hook_state, state):
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = np.inf
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

        if val_loader is not None:
            state['loader'].mode = 'val'  # added: validate on the same loader in 'val' mode
            model_utils.evaluate(state['model'],
                                 state['loader'],  # val_loader
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']

        # cast meter values (which may be CUDA tensors) to plain floats so json.dump below works
        meter_vals = {
            k: ({kk: float(vv) for kk, vv in v.items()} if isinstance(v, dict) else v)
            for k, v in meter_vals.items()
        }

        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print("==> best model (loss = {:0.6f}), saving model...".format(hook_state['best_loss']))

                state['model'].cpu()
                # used with inception
                # torch.save(state['model'].encoder.added_layers, os.path.join(opt['log.exp_dir'], 'best_model.t7'))
                torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.t7'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(opt['train.patience']))
                    state['stop'] = True
        else:
            state['model'].cpu()
            # used with inception
            # torch.save(state['model'].encoder.added_layers, os.path.join(opt['log.exp_dir'], 'best_model.t7'))
            torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.t7'))
            if opt['data.cuda']:
                state['model'].cuda()

    engine.hooks['on_end_epoch'] = partial(on_end_epoch, {})

    engine.train(
        model=model,
        loader=data,  # train_loader
        optim_method=getattr(optim, opt['train.optim_method']),
        optim_config={'lr': opt['train.learning_rate'],
                      'weight_decay': opt['train.weight_decay']},
        max_epoch=opt['train.epochs']
    )
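# Minimal interface sketch for the loader assumed in Variant 2. Only the pieces the
# script touches are modeled: the constructor kwargs, a settable `mode` attribute,
# len(), and iteration over episode dicts. The 'xs'/'xq' episode keys and the random
# tensors standing in for CUB images are illustrative assumptions, not the real class.
class PytorchBirdsDataLoaderSketch(object):
    def __init__(self, n_episodes, n_way, n_query, n_support, image_dim=(224, 224, 3)):
        self.n_episodes = n_episodes
        self.n_way = n_way
        self.n_query = n_query
        self.n_support = n_support
        self.image_dim = image_dim
        self.mode = 'train'  # the hooks above flip this between 'train' and 'val'

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        h, w, c = self.image_dim
        for _ in range(self.n_episodes):
            # one few-shot episode: n_way classes with n_support support and n_query query images each
            yield {
                'xs': torch.randn(self.n_way, self.n_support, c, h, w),  # support set
                'xq': torch.randn(self.n_way, self.n_query, c, h, w),    # query set
            }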
# Variant 3: baseline loop plus live train/val loss and accuracy plots via
# visual_utils (project-local, not shown here; a hypothetical sketch follows this variant).
def main(opt):
    # create the log directory
    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    # 'trainval' trains on the merged train+val split, so no separate validation loader is used
    if opt['data.trainval']:
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None
    else:
        data = data_utils.load(opt, ['train', 'val'])
        train_loader = data['train']
        val_loader = data['val']

    model = model_utils.load(opt)
    # model = torch.load("results/m5_5way5shot/pre.t7")

    if opt['data.cuda']:
        model.cuda()

    engine = Engine()

    meters = {'train': {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}}

    if val_loader is not None:
        meters['val'] = {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}

    # on_start: set up the learning-rate scheduler on the engine's optimizer
    def on_start(state):
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        state['scheduler'] = lr_scheduler.StepLR(state['optimizer'], opt['train.decay_every'], gamma=0.5)
    engine.hooks['on_start'] = on_start

    # at the start of each epoch: reset the meters and step the scheduler
    def on_start_epoch(state):
        for split, split_meters in meters.items():
            for field, meter in split_meters.items():
                meter.reset()
        state['scheduler'].step()
    engine.hooks['on_start_epoch'] = on_start_epoch

    # update the running-average meters
    def on_update(state):
        for field, meter in meters['train'].items():
            meter.add(state['output'][field])
    engine.hooks['on_update'] = on_update

    # at the end of each epoch: evaluate progress and decide whether to stop
    # (patience-based early stopping; the epoch-to-epoch change in loss is too
    # noisy to use directly as a stopping criterion)
    title = '%s, %s: %i_%iw_%is' % (opt['model.exp_name'],
                                    opt['data.dataset'],
                                    opt['data.way'],
                                    opt['data.test_way'],
                                    opt['data.test_shot'])
    lossPic = visual_utils.train_val_loss(title)
    accPic = visual_utils.train_val_acc(title)

    def on_end_epoch(hook_state, state):
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = np.inf
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

        if val_loader is not None:
            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        # update the live curves (assumes a validation split is present)
        lossPic(state['epoch'], meter_vals['train']['loss'], meter_vals['val']['loss'])
        accPic(state['epoch'], meter_vals['train']['acc'], meter_vals['val']['acc'])
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print("==> best model (loss = {:0.6f}), saving model...".format(hook_state['best_loss']))

                state['model'].cpu()
                torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.t7'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(opt['train.patience']))
                    state['stop'] = True
        else:
            state['model'].cpu()
            torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.t7'))
            if opt['data.cuda']:
                state['model'].cuda()

    engine.hooks['on_end_epoch'] = partial(on_end_epoch, {})

    engine.train(
        model=model,
        loader=train_loader,
        optim_method=getattr(optim, opt['train.optim_method']),
        optim_config={'lr': opt['train.learning_rate'],
                      'weight_decay': opt['train.weight_decay']},
        max_epoch=opt['train.epochs']
    )
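# visual_utils is not included in this file. From the call sites in Variant 3,
# train_val_loss(title) / train_val_acc(title) appear to return callables taking
# (epoch, train_value, val_value) and updating a live curve. Below is a hypothetical
# matplotlib sketch of such a helper, not the project's own implementation.
import matplotlib.pyplot as plt


def make_train_val_curve(title, ylabel):
    """Return a callable that appends one (train, val) point per epoch and redraws the curve."""
    epochs, train_vals, val_vals = [], [], []
    fig, ax = plt.subplots()

    def update(epoch, train_value, val_value):
        epochs.append(epoch)
        train_vals.append(train_value)
        val_vals.append(val_value)
        ax.clear()
        ax.plot(epochs, train_vals, label='train')
        ax.plot(epochs, val_vals, label='val')
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig('{}_curve.png'.format(ylabel))  # persist the curve after every epoch

    return update


# usage mirroring Variant 3 (hypothetical stand-ins for visual_utils.train_val_loss / train_val_acc):
# lossPic = make_train_val_curve(title, 'loss')
# accPic = make_train_val_curve(title, 'acc')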
# Variant 4: annotated baseline loop with validation checkpointing and
# patience-based early stopping enabled.
def main(opt):
    if not os.path.isdir(opt['log.exp_dir']):
        os.makedirs(opt['log.exp_dir'])

    # save opts to a JSON file in the experiment directory
    with open(os.path.join(opt['log.exp_dir'], 'opt.json'), 'w') as f:
        json.dump(opt, f)
        f.write('\n')

    trace_file = os.path.join(opt['log.exp_dir'], 'trace.txt')

    # Postprocess arguments
    opt['model.x_dim'] = list(map(int, opt['model.x_dim'].split(',')))
    opt['log.fields'] = opt['log.fields'].split(',')

    torch.manual_seed(1234)
    if opt['data.cuda']:
        torch.cuda.manual_seed(1234)

    # load the data
    if opt['data.trainval']:
        # load Omniglot dataset
        data = data_utils.load(opt, ['trainval'])
        train_loader = data['trainval']
        val_loader = None
    else:
        data = data_utils.load(opt, ['train', 'val'])
        train_loader = data['train']
        val_loader = data['val']

    # build the model from the model.* options
    model = model_utils.load(opt)

    if opt['data.cuda']:
        model.cuda()

    engine = Engine()

    # torchnet meters track metric values; AverageValueMeter keeps a running mean.
    # Build a dict mapping each logged field to its meter.
    meters = {'train': {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}}

    if val_loader is not None:
        meters['val'] = {field: tnt.meter.AverageValueMeter() for field in opt['log.fields']}

    def on_start(state):
        if os.path.isfile(trace_file):
            os.remove(trace_file)
        # define the learning-rate decay schedule
        state['scheduler'] = lr_scheduler.StepLR(state['optimizer'], opt['train.decay_every'], gamma=0.5)
    # register the function above as the engine's on_start hook
    engine.hooks['on_start'] = on_start

    # at the start of each epoch
    def on_start_epoch(state):
        for split, split_meters in meters.items():
            for field, meter in split_meters.items():
                # reset the meters
                meter.reset()
        # advance the learning-rate scheduler by one epoch
        state['scheduler'].step()
    # register as the on_start_epoch hook
    engine.hooks['on_start_epoch'] = on_start_epoch

    # update the meters
    def on_update(state):
        # add the latest output value for every training metric
        for field, meter in meters['train'].items():
            meter.add(state['output'][field])
    engine.hooks['on_update'] = on_update

    # at the end of each epoch, report this epoch's metric values
    def on_end_epoch(hook_state, state):
        if val_loader is not None:
            if 'best_loss' not in hook_state:
                hook_state['best_loss'] = np.inf
            if 'wait' not in hook_state:
                hook_state['wait'] = 0

        if val_loader is not None:
            # evaluate the model on the validation set
            model_utils.evaluate(state['model'],
                                 val_loader,
                                 meters['val'],
                                 desc="Epoch {:d} valid".format(state['epoch']))

        meter_vals = log_utils.extract_meter_values(meters)
        print("Epoch {:02d}: {:s}".format(state['epoch'], log_utils.render_meter_values(meter_vals)))
        meter_vals['epoch'] = state['epoch']
        with open(trace_file, 'a') as f:
            json.dump(meter_vals, f)
            f.write('\n')

        # track the best validation loss so far
        if val_loader is not None:
            if meter_vals['val']['loss'] < hook_state['best_loss']:
                hook_state['best_loss'] = meter_vals['val']['loss']
                print("==> best model (loss = {:0.6f}), saving model...".format(hook_state['best_loss']))

                state['model'].cpu()
                torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
                if opt['data.cuda']:
                    state['model'].cuda()

                hook_state['wait'] = 0
            # early stop
            else:
                hook_state['wait'] += 1

                if hook_state['wait'] > opt['train.patience']:
                    print("==> patience {:d} exceeded".format(opt['train.patience']))
                    state['stop'] = True
        else:
            # no validation split: save the model at the end of every epoch
            state['model'].cpu()
            torch.save(state['model'], os.path.join(opt['log.exp_dir'], 'best_model.pt'))
            if opt['data.cuda']:
                state['model'].cuda()

    engine.hooks['on_end_epoch'] = partial(on_end_epoch, {})

    engine.train(
        model=model,
        loader=train_loader,
        optim_method=getattr(optim, opt['train.optim_method']),
        optim_config={'lr': opt['train.learning_rate'],
                      'weight_decay': opt['train.weight_decay']},
        max_epoch=opt['train.epochs']
    )
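# All four variants above drive training through an Engine whose behavior is implied
# only by the hooks they register. Below is a minimal sketch of that hook-driven loop,
# written to match the state keys the hooks actually read ('optimizer', 'model',
# 'loader', 'epoch', 'output', 'stop'). It assumes a protonet-style model whose
# loss(sample) returns (loss, metrics dict); the real protonets Engine differs in
# details such as progress bars, extra hooks, and additional state keys.
class EngineSketch(object):
    def __init__(self):
        # every hook defaults to a no-op so scripts only register what they need
        self.hooks = {name: (lambda state: None) for name in
                      ('on_start', 'on_start_epoch', 'on_update', 'on_end_epoch', 'on_end')}

    def train(self, model, loader, optim_method, optim_config, max_epoch):
        state = {
            'model': model,
            'loader': loader,
            'optimizer': optim_method(model.parameters(), **optim_config),
            'epoch': 0,
            'stop': False,
        }
        self.hooks['on_start'](state)                 # e.g. attach the LR scheduler
        while state['epoch'] < max_epoch and not state['stop']:
            state['model'].train()
            self.hooks['on_start_epoch'](state)       # reset meters, step the scheduler
            for sample in state['loader']:
                state['optimizer'].zero_grad()
                loss, state['output'] = state['model'].loss(sample)  # assumed model API
                loss.backward()
                state['optimizer'].step()
                self.hooks['on_update'](state)        # feed the batch metrics to the meters
            state['epoch'] += 1
            self.hooks['on_end_epoch'](state)         # validate, log, checkpoint, early stop
        self.hooks['on_end'](state)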