def run_test():
    """Load the checkpoint named by ``opt.resume`` and run ``test`` on it.

    Builds the model selected by ``opt.model``, optionally hands it to
    ``multiple_gpu_process`` (which also updates the module-level
    ``use_data_parallel`` flag), restores the weights, wraps
    ``opt.test_data_list`` in a ``DataLoader`` and delegates the actual
    evaluation to ``test``.
    """
    global use_data_parallel

    net = getattr(models, opt.model)(use_imp=opt.use_imp, n=opt.feat_num)
    if opt.use_gpu:
        # Either ``model.cuda()`` or ``model = model.cuda()`` would work here;
        # the helper takes care of device placement instead.
        net, use_data_parallel = multiple_gpu_process(net)

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    # Alternative statistics that were tried:
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    test_batch_size = 1

    # Centre 128x128 crops (selected when the batch size is not 1).
    test_data_transforms_crops = transforms.Compose([
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        normalize,
    ])
    # Full-resolution images (selected when the batch size is 1).
    test_data_transforms_full = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if test_batch_size == 1:
        test_data_transforms = test_data_transforms_full
    else:
        test_data_transforms = test_data_transforms_crops

    # ``load`` lives on the wrapped module when the model was wrapped for
    # data-parallel execution (flag value 1).
    test_ckpt = opt.resume
    loader_owner = net.module if use_data_parallel == 1 else net
    load_epoch = loader_owner.load(None, test_ckpt)
    print('Load epoch %d for test!' % load_epoch)

    test_data = ImageFilelist(flist=opt.test_data_list,
                              transform=test_data_transforms)
    test_dataloader = DataLoader(test_data, test_batch_size, shuffle=False)

    mse_loss = t.nn.MSELoss(size_average=False)
    rate_loss = (RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)
                 if opt.use_imp else None)

    test(net, test_dataloader, test_batch_size, mse_loss, rate_loss)
def train(**kwargs):
    """Train the compression model on crops that carry bounding-box masks.

    Keyword arguments are forwarded to ``opt.parse`` and override the
    global configuration.  The function builds the model, the
    ``ImageCropWithBBoxMaskDataset`` loaders, an Adam optimizer and the
    plot/log writer, then runs the epoch loop: optional initial
    validation, training, checkpointing, periodic validation and three
    independent learning-rate decay triggers (early adjust, fixed epoch
    interval, decay-file indicator).

    NOTE(review): within the visible source a later ``def train`` rebinds
    this name, so this variant is only reachable if that definition is
    renamed or removed.

    Requires PyTorch >= 0.4 (``non_blocking=``, ``Tensor.item()``).
    """
    # ``multiple_gpu_process`` is unpacked as ``(model, flag)`` by the other
    # callers in this file; keep the shared flag in sync here as well.
    global use_data_parallel

    opt.parse(kwargs)

    # log file
    logfile_name = ("Cmpr_with_YOLOv2_" + opt.exp_desc +
                    time.strftime("_%m_%d_%H:%M:%S") + ".log.txt")
    ps = PlotSaver(logfile_name)

    # step1: Model
    if opt.use_imp:
        model_name = ("Cmpr_yolo_imp_" + opt.exp_desc +
                      "_r={r}_gama={w}".format(r=opt.rate_loss_threshold,
                                               w=opt.rate_loss_weight))
    else:
        model_name = "Cmpr_yolo_no_imp_" + opt.exp_desc
    model = getattr(models, opt.model)(use_imp=opt.use_imp,
                                       n=opt.feat_num,
                                       input_4_ch=opt.input_4_ch,
                                       model_name=model_name)
    if opt.use_gpu:
        model, use_data_parallel = multiple_gpu_process(model)
    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    train_data_transforms = transforms.Compose([
        # transforms.RandomHorizontalFlip(),
        # TODO: reimplement the flip so it acts on label and data together
        transforms.ToTensor(),
        normalize,
    ])
    val_data_transforms = transforms.Compose(
        [transforms.ToTensor(), normalize])
    train_data = ImageCropWithBBoxMaskDataset(
        opt.train_data_list,
        train_data_transforms,
        contrastive_degree=opt.contrastive_degree,
        mse_bbox_weight=opt.input_original_bbox_weight)
    val_data = ImageCropWithBBoxMaskDataset(
        opt.val_data_list,
        val_data_transforms,
        contrastive_degree=opt.contrastive_degree,
        mse_bbox_weight=opt.input_original_bbox_weight)
    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer
    if opt.use_imp:
        # TODO: new rate loss
        rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    def weighted_mse_loss(input, target, weight):
        """Sum of squared errors, scaled by ``opt.mse_bbox_weight`` wherever
        ``weight`` equals ``opt.input_original_bbox_inner`` and by 1
        elsewhere."""
        weight_clone = t.ones_like(weight)
        weight_clone[weight ==
                     opt.input_original_bbox_inner] = opt.mse_bbox_weight
        return t.sum(weight_clone * (input - target)**2)

    def yolo_rate_loss(imp_map, mask_r):
        """Rate loss on the importance map; ``mask_r`` is currently unused
        (a mask-aware V2 loss was tried and left disabled)."""
        return rate_loss(imp_map)

    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    def set_lr(new_lr):
        """Write ``new_lr`` into every optimizer parameter group."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

    def run_validation():
        """Run ``val``, record and plot its results.

        Returns ``(mse, rate, total, rate_display)``; the last three are
        ``None`` when ``opt.use_imp`` is off.
        """
        if opt.use_imp:
            mse_v, rate_v, total_v, rate_disp_v = val(
                model, val_dataloader, weighted_mse_loss, yolo_rate_loss, ps)
        else:
            mse_v = val(model, val_dataloader, weighted_mse_loss, None, ps)
            rate_v = total_v = rate_disp_v = None
        ps.add_point('val mse loss', mse_v)
        if opt.use_imp:
            ps.add_point('val rate value', rate_disp_v)
            ps.add_point('val rate loss', rate_v)
            ps.add_point('val total loss', total_v)
        ps.make_plot('val mse loss')
        if opt.use_imp:
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')
        return mse_v, rate_v, total_v, rate_disp_v

    start_epoch = 0
    # The decay file's modification time is remembered so that one and the
    # same file does not trigger the decay more than once.
    decay_file_create_time = -1

    if opt.resume:
        # ``== 1`` (not plain truthiness) matches the other callers and
        # stays correct if the flag can take a negative "no GPU" value.
        net = model.module if use_data_parallel == 1 else model
        start_epoch = net.load(None if opt.finetune else optimizer,
                               opt.resume, opt.finetune)
        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training at epoch %d.' % start_epoch)

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()
    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0

    # ps init
    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):
        same_lr_epoch += 1

        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            # Without the importance map the total loss is the MSE loss.
            total_loss_meter = mse_loss_meter

        # The per-epoch curve is recreated at the start of every epoch.
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            (mse_val_loss, rate_val_loss, total_val_loss,
             rate_val_display) = run_validation()
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, '
                    'val_mse_loss: {val_mse_loss}, '
                    'val_rate_loss: {val_rate_loss}, '
                    'val_total_loss: {val_total_loss}, '
                    'val_rate_display: {val_rate_display} '.format(
                        epoch=epoch,
                        lr=lr,
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, '
                    'val_mse_loss:{val_mse_loss}'.format(
                        epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))
            if opt.only_init_val:
                print('Only Init Val Over!')
                return

        model.train()
        if epoch == start_epoch + 1:
            print('Start training, please inspect log file %s!' %
                  logfile_name)

        # mask is the detection bounding box mask
        for idx, (data, mask, o_mask) in enumerate(train_dataloader):
            data = Variable(data)
            mask = Variable(mask)
            o_mask = Variable(o_mask, requires_grad=False)
            if opt.use_gpu:
                # ``async`` is a reserved word since Python 3.7; PyTorch
                # renamed the argument to ``non_blocking``.
                data = data.cuda(non_blocking=True)
                mask = mask.cuda(non_blocking=True)
                o_mask = o_mask.cuda(non_blocking=True)

            optimizer.zero_grad()
            reconstructed, imp_mask_sigmoid = model(data, mask, o_mask)

            loss = weighted_mse_loss(reconstructed, data, o_mask)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_display = imp_mask_sigmoid
                rate_loss_ = yolo_rate_loss(rate_loss_display, mask)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()
            optimizer.step()

            # ``.item()`` replaces ``.data[0]``, which fails on 0-dim tensors.
            caffe_value = caffe_loss.item()
            mse_loss_meter.add(caffe_value)
            if opt.use_imp:
                rate_value = rate_loss_.item()
                rate_display_value = rate_loss_display.data.mean().item()
                total_value = total_loss.item()
                rate_loss_meter.add(rate_value)
                rate_display_meter.add(rate_display_value)
                total_loss_meter.add(total_value)

            if idx % opt.print_freq == opt.print_freq - 1:
                smooth = opt.print_smooth
                mse_point = (mse_loss_meter.value()[0]
                             if smooth else caffe_value)
                ps.add_point('train mse loss', mse_point)
                ps.add_point('cur epoch train mse loss', mse_point)
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if smooth else rate_display_value)
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0] if smooth else rate_value)
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if smooth else total_value)

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, '
                        'mse_loss = %.2f, rate_loss = %.2f, '
                        'rate_display = %.2f, lr = %.8f' %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0],
                         rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))

                # Enter debug mode when the debug flag file exists.
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if epoch % opt.save_interval == 0:
            print('save checkpoint file of epoch %d.' % epoch)
            if use_data_parallel == 1:
                model.module.save(optimizer, epoch)
            else:
                model.save(optimizer, epoch)

        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        if epoch % opt.eval_interval == 0:
            print('Validating ...')
            (mse_val_loss, rate_val_loss, total_val_loss,
             rate_val_display) = run_validation()
            if opt.use_imp:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, '
                    'train_mse_loss: {train_mse_loss}, '
                    'train_rate_loss: {train_rate_loss}, '
                    'train_total_loss: {train_total_loss}, '
                    'train_rate_display: {train_rate_display} \n '
                    'val_mse_loss: {val_mse_loss}, '
                    'val_rate_loss: {val_rate_loss}, '
                    'val_total_loss: {val_total_loss}, '
                    'val_rate_display: {val_rate_display} '.format(
                        epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, '
                    'train_mse_loss:{train_mse_loss}, '
                    'val_mse_loss:{val_mse_loss}'.format(
                        epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptive lr adjustment: at the current lr, once the epoch loss has
        # exceeded the previous epoch's loss ``opt.tolerant_max`` times, decay
        # the learning rate.  (The comparison uses the *training* total loss.)
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    set_lr(lr)
                    print('Due to early stop anneal lr to %.10f at epoch %d' %
                          (lr, epoch))
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))
            else:
                tolerant_now -= 1

        if epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            set_lr(lr)
            print('Anneal lr to %.10f at epoch %d due to full epochs.' %
                  (lr, epoch))
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            cur_mtime = os.path.getmtime(opt.lr_decay_file)
            if cur_mtime > decay_file_create_time:
                decay_file_create_time = cur_mtime
                lr = lr * opt.lr_decay
                set_lr(lr)
                print(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))
                ps.log(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))

        previous_loss = total_loss_meter.value()[0]
def train(**kwargs):
    """Train (or only validate) one compression model; optionally chain runs.

    Keyword arguments are forwarded to ``opt.parse`` and override the
    global configuration.  With ``opt.use_batch_process`` set, the
    experiment indexed by the module-level ``batch_model_id`` is selected
    from ``opt.exp_ids`` / ``opt.r_s`` / ``opt.exp_desc_LUT``; when the
    epoch loop finishes, the counter is advanced and ``train`` calls itself
    with the same keyword arguments until every experiment has been
    processed.

    The function returns early when ``opt.make_caffe_data`` is set (after
    exporting the test set via ``save_caffe_data``) or when
    ``opt.only_init_val`` is set (after the initial validation pass).

    Requires PyTorch >= 0.4 (``non_blocking=``, ``Tensor.item()``).
    """
    global batch_model_id
    global use_data_parallel

    opt.parse(kwargs)

    if opt.use_batch_process:
        max_bp_times = len(opt.exp_ids)
        if batch_model_id >= max_bp_times:
            print('Batch Processing Done!')
            return
        else:
            print('Cur Batch Processing ID is %d/%d.' %
                  (batch_model_id + 1, max_bp_times))
            opt.r = opt.r_s[batch_model_id]
            opt.exp_id = opt.exp_ids[batch_model_id]
            opt.exp_desc = opt.exp_desc_LUT[opt.exp_id]
            opt.plot_path = "plot/plot_%d" % opt.exp_id
            print('Cur Model(exp_%d) r = %f, desc = %s. ' %
                  (opt.exp_id, opt.r, opt.exp_desc))
            opt.make_new_dirs()

    # log file: evaluation-only runs get a "_val" / "_test" suffix
    EvalVal = opt.only_init_val and opt.init_val and not opt.test_test
    EvalTest = opt.only_init_val and opt.init_val and opt.test_test
    EvalSuffix = ""
    if EvalVal:
        EvalSuffix = "_val"
    if EvalTest:
        EvalSuffix = "_test"
    logfile_name = (opt.exp_desc + time.strftime("_%m_%d_%H:%M:%S") +
                    EvalSuffix + ".log.txt")
    ps = PlotSaver(logfile_name, log_to_stdout=opt.log_to_stdout)

    # step1: Model
    model = getattr(models, opt.model)(use_imp=opt.use_imp,
                                       n=opt.feat_num,
                                       model_name=opt.exp_desc)
    if opt.use_gpu:
        model, use_data_parallel = multiple_gpu_process(model)
    # Whether the GPU is really used: a negative flag disables it.
    opt.use_gpu = opt.use_gpu and use_data_parallel >= 0
    cudnn.benchmark = True

    # step2: Data
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    train_data_transforms = transforms.Compose([
        transforms.RandomCrop(128),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_data_transforms = transforms.Compose(
        [transforms.CenterCrop(128),
         transforms.ToTensor(), normalize])
    caffe_data_transforms = transforms.Compose([transforms.CenterCrop(128)])

    train_data = ImageFilelist(
        flist=opt.train_data_list,
        transform=train_data_transforms,
    )
    val_data = ImageFilelist(
        flist=opt.val_data_list,
        prefix=opt.val_data_prefix,
        transform=val_data_transforms,
    )
    # NOTE(review): ``val_data_caffe`` is never read below; it is still
    # constructed in case ``ImageFilelist`` validates the list on creation.
    val_data_caffe = ImageFilelist(flist=opt.val_data_list,
                                   transform=caffe_data_transforms)
    test_data_caffe = ImageFilelist(flist=opt.test_data_list,
                                    transform=caffe_data_transforms)
    if opt.make_caffe_data:
        save_caffe_data(test_data_caffe)
        print('Make caffe dataset over!')
        return

    train_dataloader = DataLoader(train_data,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers,
                                  pin_memory=True)
    val_dataloader = DataLoader(val_data,
                                opt.batch_size,
                                shuffle=False,
                                num_workers=opt.num_workers,
                                pin_memory=True)

    # step3: criterion and optimizer
    # NOTE(review): ``size_average`` is deprecated in current PyTorch in
    # favour of ``reduction='sum'``; it is kept because it is still accepted.
    mse_loss = t.nn.MSELoss(size_average=False)
    if opt.use_imp:
        # TODO: new rate loss
        rate_loss = RateLoss(opt.rate_loss_threshold, opt.rate_loss_weight)

    start_epoch = 0
    # The decay file's modification time is remembered so that one and the
    # same file does not trigger the decay more than once.
    decay_file_create_time = -1
    previous_loss = 1e100
    tolerant_now = 0
    same_lr_epoch = 0
    lr = opt.lr
    optimizer = t.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

    def set_lr(new_lr):
        """Write ``new_lr`` into every optimizer parameter group."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr

    def run_validation():
        """Run ``val``, record and plot its results.

        Returns ``(mse, rate, total, rate_display)``; the last three are
        ``None`` when ``opt.use_imp`` is off.
        """
        if opt.use_imp:
            mse_v, rate_v, total_v, rate_disp_v = val(
                model, val_dataloader, mse_loss, rate_loss, ps)
        else:
            mse_v = val(model, val_dataloader, mse_loss, None, ps)
            rate_v = total_v = rate_disp_v = None
        ps.add_point('val mse loss', mse_v)
        if opt.use_imp:
            ps.add_point('val rate value', rate_disp_v)
            ps.add_point('val rate loss', rate_v)
            ps.add_point('val total loss', total_v)
        ps.make_plot('val mse loss')
        if opt.use_imp:
            ps.make_plot('val rate value')
            ps.make_plot('val rate loss')
            ps.make_plot('val total loss')
        return mse_v, rate_v, total_v, rate_disp_v

    if opt.resume:
        net = model.module if use_data_parallel == 1 else model
        start_epoch = net.load(None if opt.finetune else optimizer,
                               opt.resume, opt.finetune)
        if opt.finetune:
            print('Finetune from model checkpoint file', opt.resume)
        else:
            print('Resume training from checkpoint file', opt.resume)
            print('Continue training from epoch %d.' % start_epoch)
            # Replay the interval-based decays that had already happened
            # before the checkpoint was written.
            same_lr_epoch = start_epoch % opt.lr_anneal_epochs
            decay_times = start_epoch // opt.lr_anneal_epochs
            lr = opt.lr * (opt.lr_decay**decay_times)
            set_lr(lr)
            print('Decay lr %d times, now lr is %e.' % (decay_times, lr))

    # step4: meters
    mse_loss_meter = AverageValueMeter()
    if opt.use_imp:
        rate_loss_meter = AverageValueMeter()
        rate_display_meter = AverageValueMeter()
        total_loss_meter = AverageValueMeter()

    # ps init
    ps.new_plot('train mse loss',
                opt.print_freq,
                xlabel="iteration",
                ylabel="train_mse_loss")
    ps.new_plot('val mse loss', 1, xlabel="epoch", ylabel="val_mse_loss")
    if opt.use_imp:
        ps.new_plot('train rate value',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_value")
        ps.new_plot('train rate loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_rate_loss")
        ps.new_plot('train total loss',
                    opt.print_freq,
                    xlabel="iteration",
                    ylabel="train_total_loss")
        ps.new_plot('val rate value',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_value")
        ps.new_plot('val rate loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_rate_loss")
        ps.new_plot('val total loss',
                    1,
                    xlabel="iteration",
                    ylabel="val_total_loss")

    # When only validating a checkpoint whose epoch already equals
    # ``max_epoch`` (e.g. both 600), extend the range so the loop body --
    # and with it the initial validation -- still runs.
    if opt.only_init_val and opt.max_epoch <= start_epoch:
        opt.max_epoch = start_epoch + 2

    for epoch in range(start_epoch + 1, opt.max_epoch + 1):
        same_lr_epoch += 1

        # per epoch avg loss meter
        mse_loss_meter.reset()
        if opt.use_imp:
            rate_display_meter.reset()
            rate_loss_meter.reset()
            total_loss_meter.reset()
        else:
            # Without the importance map the total loss is the MSE loss.
            total_loss_meter = mse_loss_meter

        # The per-epoch curve is recreated at the start of every epoch.
        ps.new_plot("cur epoch train mse loss",
                    opt.print_freq,
                    xlabel="iteration in cur epoch",
                    ylabel="train_mse_loss")

        # Init val
        if (epoch == start_epoch + 1) and opt.init_val:
            print('Init validation ... ')
            (mse_val_loss, rate_val_loss, total_val_loss,
             rate_val_display) = run_validation()
            if opt.use_imp:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, '
                    'val_mse_loss: {val_mse_loss}, '
                    'val_rate_loss: {val_rate_loss}, '
                    'val_total_loss: {val_total_loss}, '
                    'val_rate_display: {val_rate_display} '.format(
                        epoch=epoch,
                        lr=lr,
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Init Val @ Epoch:{epoch}, lr:{lr}, '
                    'val_mse_loss: {val_mse_loss}'.format(
                        epoch=epoch, lr=lr, val_mse_loss=mse_val_loss))
            if opt.only_init_val:
                print('Only Init Val Over!')
                return

        model.train()

        for idx, data in enumerate(train_dataloader):
            data = Variable(data)
            if opt.use_gpu:
                # ``async`` is a reserved word since Python 3.7; PyTorch
                # renamed the argument to ``non_blocking``.
                data = data.cuda(non_blocking=True)

            optimizer.zero_grad()
            if opt.use_imp:
                reconstructed, imp_mask_sigmoid = model(data)
            else:
                reconstructed = model(data)

            loss = mse_loss(reconstructed, data)
            caffe_loss = loss / (2 * opt.batch_size)

            if opt.use_imp:
                rate_loss_ = rate_loss(imp_mask_sigmoid)
                total_loss = caffe_loss + rate_loss_
            else:
                total_loss = caffe_loss

            total_loss.backward()
            optimizer.step()

            # ``.item()`` replaces ``.data[0]``, which fails on 0-dim tensors.
            caffe_value = caffe_loss.item()
            mse_loss_meter.add(caffe_value)
            if opt.use_imp:
                rate_value = rate_loss_.item()
                rate_display_value = imp_mask_sigmoid.data.mean().item()
                total_value = total_loss.item()
                rate_loss_meter.add(rate_value)
                rate_display_meter.add(rate_display_value)
                total_loss_meter.add(total_value)

            if idx % opt.print_freq == opt.print_freq - 1:
                smooth = opt.print_smooth
                mse_point = (mse_loss_meter.value()[0]
                             if smooth else caffe_value)
                ps.add_point('train mse loss', mse_point)
                ps.add_point('cur epoch train mse loss', mse_point)
                if opt.use_imp:
                    ps.add_point(
                        'train rate value',
                        rate_display_meter.value()[0]
                        if smooth else rate_display_value)
                    ps.add_point(
                        'train rate loss',
                        rate_loss_meter.value()[0] if smooth else rate_value)
                    ps.add_point(
                        'train total loss',
                        total_loss_meter.value()[0]
                        if smooth else total_value)

                if not opt.use_imp:
                    ps.log('Epoch %d/%d, Iter %d/%d, loss = %.2f, lr = %.8f' %
                           (epoch, opt.max_epoch, idx, len(train_dataloader),
                            total_loss_meter.value()[0], lr))
                else:
                    ps.log(
                        'Epoch %d/%d, Iter %d/%d, loss = %.2f, '
                        'mse_loss = %.2f, rate_loss = %.2f, '
                        'rate_display = %.2f, lr = %.8f' %
                        (epoch, opt.max_epoch, idx, len(train_dataloader),
                         total_loss_meter.value()[0],
                         mse_loss_meter.value()[0],
                         rate_loss_meter.value()[0],
                         rate_display_meter.value()[0], lr))

                # Enter debug mode when the debug flag file exists.
                if os.path.exists(opt.debug_file):
                    pdb.set_trace()

        if epoch % opt.save_interval == 0:
            print('Save checkpoint file of epoch %d.' % epoch)
            if use_data_parallel == 1:
                model.module.save(optimizer, epoch)
            else:
                model.save(optimizer, epoch)

        ps.make_plot('train mse loss')
        ps.make_plot('cur epoch train mse loss', epoch)
        if opt.use_imp:
            ps.make_plot("train rate value")
            ps.make_plot("train rate loss")
            ps.make_plot("train total loss")

        if epoch % opt.eval_interval == 0:
            print('Validating ...')
            (mse_val_loss, rate_val_loss, total_val_loss,
             rate_val_display) = run_validation()
            if opt.use_imp:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, '
                    'train_mse_loss: {train_mse_loss}, '
                    'train_rate_loss: {train_rate_loss}, '
                    'train_total_loss: {train_total_loss}, '
                    'train_rate_display: {train_rate_display} \n '
                    'val_mse_loss: {val_mse_loss}, '
                    'val_rate_loss: {val_rate_loss}, '
                    'val_total_loss: {val_total_loss}, '
                    'val_rate_display: {val_rate_display} '.format(
                        epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        train_rate_loss=rate_loss_meter.value()[0],
                        train_total_loss=total_loss_meter.value()[0],
                        train_rate_display=rate_display_meter.value()[0],
                        val_mse_loss=mse_val_loss,
                        val_rate_loss=rate_val_loss,
                        val_total_loss=total_val_loss,
                        val_rate_display=rate_val_display))
            else:
                ps.log(
                    'Epoch:{epoch}, lr:{lr}, '
                    'train_mse_loss:{train_mse_loss}, '
                    'val_mse_loss:{val_mse_loss}'.format(
                        epoch=epoch,
                        lr=lr,
                        train_mse_loss=mse_loss_meter.value()[0],
                        val_mse_loss=mse_val_loss))

        # Adaptive lr adjustment: at the current lr, once the epoch loss has
        # exceeded the previous epoch's loss ``opt.tolerant_max`` times in a
        # row, decay the learning rate.  (The comparison uses the *training*
        # total loss; any improvement resets the counter.)
        if opt.use_early_adjust:
            if total_loss_meter.value()[0] > previous_loss:
                tolerant_now += 1
                if tolerant_now == opt.tolerant_max:
                    tolerant_now = 0
                    same_lr_epoch = 0
                    lr = lr * opt.lr_decay
                    set_lr(lr)
                    print('Due to early stop anneal lr to %.10f at epoch %d' %
                          (lr, epoch))
                    ps.log('Due to early stop anneal lr to %.10f at epoch %d' %
                           (lr, epoch))
            else:
                tolerant_now = 0

        if epoch % opt.lr_anneal_epochs == 0:
            same_lr_epoch = 0
            tolerant_now = 0
            lr = lr * opt.lr_decay
            set_lr(lr)
            print('Anneal lr to %.10f at epoch %d due to full epochs.' %
                  (lr, epoch))
            ps.log('Anneal lr to %.10f at epoch %d due to full epochs.' %
                   (lr, epoch))

        if opt.use_file_decay_lr and os.path.exists(opt.lr_decay_file):
            cur_mtime = os.path.getmtime(opt.lr_decay_file)
            if cur_mtime > decay_file_create_time:
                decay_file_create_time = cur_mtime
                lr = lr * opt.lr_decay
                set_lr(lr)
                print(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))
                ps.log(
                    'Anneal lr to %.10f at epoch %d due to decay-file indicator.'
                    % (lr, epoch))

        previous_loss = total_loss_meter.value()[0]

    if opt.use_batch_process:
        batch_model_id += 1
        # ``train`` accepts keyword arguments only, so the dict must be
        # unpacked; passing it positionally raises TypeError.
        train(**kwargs)