# Tiled inference over every file found in `base_path`.
# This is a script fragment: `model`, `device`, `cfg` and `base_path` are
# defined before this excerpt, and the inner tile loop continues past its end
# (the prediction is not yet cropped back or stored here).
model = model.to(device)
PAD = cfg.TRAIN.pad  # number of pixels added on every side of a tile before the forward pass
img_list = os.listdir(base_path)
for f_img in img_list:
    print('Inference: ' + f_img, end=' ')
    # raw = np.asarray(Image.open(os.path.join(base_path, f_img)).convert('L'))
    raw = np.asarray(Image.open(os.path.join(base_path, f_img)))
    # Move the last axis to the front; this requires a 3-D array
    # (presumably H x W x C -> C x H x W -- TODO confirm).
    raw = raw.transpose(2, 0, 1)
    if cfg.TRAIN.track == 'complex':
        if raw.shape[0] == 9959 or raw.shape[0] == 9958:
            # Embed the image in a zero-filled 10240 x 10240 canvas at offset 141.
            # NOTE(review): `raw` was transposed as a 3-D array above, while the
            # canvas and the slice assigned here are 2-D -- confirm the shapes
            # are actually compatible for this track.
            raw_ = np.zeros((10240,10240), dtype=np.uint8)
            raw_[141:141+9959, 141:141+9958] = raw
            raw = raw_
            del raw_
    crop_img = Crop_image(raw,crop_size=cfg.TRAIN.crop_size,overlap=cfg.TRAIN.overlap)
    start = time.time()
    # Visit the crop_img.num x crop_img.num grid of tiles.
    for i in range(crop_img.num):
        for j in range(crop_img.num):
            raw_crop = crop_img.gen(i, j)
            #########
            #inference
            # Add a batch axis (and a channel axis when the tile has none) so the
            # tile becomes a 4-D array.
            if crop_img.dim == 3:
                raw_crop_ = raw_crop[np.newaxis, :, :, :]
            else:
                raw_crop_ = raw_crop[np.newaxis, np.newaxis, :, :]
            inputs = torch.Tensor(raw_crop_).to(device)
            inputs = F.pad(inputs, (PAD, PAD, PAD, PAD))
            with torch.no_grad():
                pred = model(inputs)
            # pred = pred[0]
def _predict_tiled(model, raw, cfg, pad, device):
    """Predict a full image tile by tile and return the stitched result.

    The image is split by ``Crop_image`` using ``cfg.TRAIN.crop_size`` and
    ``cfg.TRAIN.overlap``. Every tile is padded by ``pad`` pixels on each
    side, run through ``model`` without gradients, cropped back by ``pad``
    and handed to ``crop_img.save``. Returns ``crop_img.result()``.
    """
    crop_img = Crop_image(raw, crop_size=cfg.TRAIN.crop_size, overlap=cfg.TRAIN.overlap)
    for i in range(crop_img.num):
        for j in range(crop_img.num):
            raw_crop = crop_img.gen(i, j)
            # Add the batch axis (and a channel axis when the tile has none).
            if crop_img.dim == 3:
                raw_crop_ = raw_crop[np.newaxis, :, :, :]
            else:
                raw_crop_ = raw_crop[np.newaxis, np.newaxis, :, :]
            inputs = torch.Tensor(raw_crop_).to(device)
            inputs = F.pad(inputs, (pad, pad, pad, pad))
            with torch.no_grad():
                pred = model(inputs)
            # Negative padding crops the border that was added above.
            pred = F.pad(pred, (-pad, -pad, -pad, -pad))
            pred = np.squeeze(pred.data.cpu().numpy())
            crop_img.save(i, j, pred)
    return crop_img.result()


def loop(cfg, train_provider, valid_provider, model, criterion, optimizer, iters, writer):
    """Train ``model`` with periodic logging, preview images and validation.

    Args:
        cfg: Configuration object. Reads ``cfg.record_path``,
            ``cfg.cache_path``, ``cfg.save_path`` and the ``cfg.TRAIN`` fields
            ``pad``, ``thresd``, ``total_iters``, ``base_lr``, ``end_lr``,
            ``weight_decay``, ``display_freq``, ``valid_freq``, ``save_freq``,
            ``if_valid``, ``crop_size`` and ``overlap``.
        train_provider: Object whose ``next()`` returns ``(input, target)``.
        valid_provider: Object exposing ``data.num`` and ``data.gen(k)``
            returning ``(raw, label)``; only used when ``cfg.TRAIN.if_valid``.
        model: Network to train; called as ``model(input)``.
        criterion: Loss callable, called as ``criterion(pred, target)``.
        optimizer: Optimizer whose ``param_groups`` are updated in place.
        iters: Iteration count to resume from.
        writer: Summary writer providing ``add_scalar(tag, value, step)``.

    Returns:
        None. Results are written to ``loss.txt`` / ``valid.txt`` (appended),
        to preview images in ``cfg.cache_path`` and to checkpoints in
        ``cfg.save_path``.

    NOTE(review): the counter is incremented at the top of the body while the
    condition is ``iters <= total_iters``, so the last executed step is
    ``total_iters + 1`` -- confirm this is intended.
    """
    PAD = cfg.TRAIN.pad
    thresd = cfg.TRAIN.thresd
    rcd_time = []
    sum_time = 0
    sum_loss = 0
    most_f1 = 0
    most_iters = 0
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # The context manager guarantees both log files are closed, even when
    # training is interrupted by an exception.
    with open(os.path.join(cfg.record_path, 'loss.txt'), 'a') as f_loss_txt, \
            open(os.path.join(cfg.record_path, 'valid.txt'), 'a') as f_valid_txt:
        while iters <= cfg.TRAIN.total_iters:
            # train
            iters += 1
            t1 = time.time()
            batch, target = train_provider.next()

            # decay learning rate
            if cfg.TRAIN.end_lr == cfg.TRAIN.base_lr:
                current_lr = cfg.TRAIN.base_lr
            else:
                current_lr = calculate_lr(iters)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr

            optimizer.zero_grad()
            batch = F.pad(batch, (PAD, PAD, PAD, PAD))
            pred = model(batch)
            pred = F.pad(pred, (-PAD, -PAD, -PAD, -PAD))
            loss = criterion(pred, target)
            loss.backward()
            if cfg.TRAIN.weight_decay is not None:
                # Decoupled weight decay: p <- p - wd * lr * p. Uses the
                # keyword ``alpha`` form; the positional (Number, Tensor)
                # overload is deprecated.
                for group in optimizer.param_groups:
                    decay = -cfg.TRAIN.weight_decay * group['lr']
                    for param in group['params']:
                        param.data = param.data.add(param.data, alpha=decay)
            optimizer.step()

            sum_loss += loss.item()
            sum_time += time.time() - t1

            # log train
            if iters % cfg.TRAIN.display_freq == 0:
                rcd_time.append(sum_time)
                shown_loss = sum_loss / cfg.TRAIN.display_freq * 10
                remaining_min = ((cfg.TRAIN.total_iters - iters) / cfg.TRAIN.display_freq
                                 * np.mean(np.asarray(rcd_time)) / 60)
                logging.info('step %d, loss = %.4f (wt: *10, lr: %.8f, et: %.2f sec, rd: %.2f min)'
                             % (iters, shown_loss, current_lr, sum_time, remaining_min))
                writer.add_scalar('loss', shown_loss, iters)
                f_loss_txt.write('step = ' + str(iters) + ', loss = ' + str(shown_loss))
                f_loss_txt.write('\n')
                f_loss_txt.flush()
                sys.stdout.flush()
                sum_time = 0
                sum_loss = 0

            # valid: dump a side-by-side preview of the first training sample
            if iters % cfg.TRAIN.valid_freq == 0:
                batch = F.pad(batch, (-PAD, -PAD, -PAD, -PAD))
                input0 = (np.squeeze(batch[0].data.cpu().numpy()) * 255).astype(np.uint8)
                target_img = (np.squeeze(target[0].data.cpu().numpy()) * 255).astype(np.uint8)
                # np.clip returns a new array, so the tensor's storage (which
                # the NumPy view may share on CPU) is not modified.
                pred_img = np.clip(np.squeeze(pred[0].data.cpu().numpy()), 0, 1)
                pred_img = (pred_img * 255).astype(np.uint8)
                input0 = input0[0]
                im_cat = np.concatenate([input0, pred_img, target_img], axis=1)
                Image.fromarray(im_cat).save(os.path.join(cfg.cache_path, '%06d.png' % iters))

            # save
            if iters % cfg.TRAIN.save_freq == 0:
                if cfg.TRAIN.if_valid:
                    # Scores are collected per validation round so the reported
                    # F1 (and the best-checkpoint decision) reflects the
                    # current weights only, not an average over past rounds.
                    valid_score = []
                    valid_score_tmp = []
                    for k in range(valid_provider.data.num):
                        raw, label = valid_provider.data.gen(k)
                        results = _predict_tiled(model, raw, cfg, PAD, device)
                        results[results<=thresd] = 0
                        results[results>thresd] = 1
                        temp_label = label.flatten()
                        temp_result = results.flatten()
                        # F1 is computed on the inverted masks.
                        f1_common = f1_score(1 - temp_label, 1 - temp_result)
                        f1 = 0  # placeholder: no soft F1 is computed here
                        results_img = (results * 255).astype(np.uint8)
                        label_img = (label * 255).astype(np.uint8)
                        im_cat_valid = np.concatenate([results_img, label_img], axis=1)
                        if k == valid_provider.data.num - 1:
                            Image.fromarray(im_cat_valid).save(
                                os.path.join(cfg.cache_path, 'valid_%06d.png' % iters))
                        valid_score.append(f1)
                        valid_score_tmp.append(f1_common)
                    avg_f1_soft = sum(valid_score) / len(valid_score)
                    avg_f1_tmp = sum(valid_score_tmp) / len(valid_score_tmp)
                    logging.info('step %d, f1_soft = %.6f' % (iters, avg_f1_soft))
                    logging.info('step %d, f1_common = %.6f' % (iters, avg_f1_tmp))
                    writer.add_scalar('valid', avg_f1_tmp, iters)
                    f_valid_txt.write('step %d, f1 = %.6f' % (iters, avg_f1_tmp))
                    f_valid_txt.write('\n')
                    f_valid_txt.flush()
                    sys.stdout.flush()
                    # Keep only the best-scoring weights.
                    if avg_f1_tmp > most_f1:
                        most_f1 = avg_f1_tmp
                        most_iters = iters
                        states = {'current_iter': iters, 'valid_result': avg_f1_tmp,
                                  'model_weights': model.state_dict()}
                        torch.save(states, os.path.join(cfg.save_path, 'model.ckpt'))
                        print('***************save modol, when f1 = %.6f and iters = %d.***************' % (most_f1, most_iters))
                else:
                    # Without validation every snapshot is kept under its own name.
                    states = {'current_iter': iters, 'valid_result': None,
                              'model_weights': model.state_dict()}
                    torch.save(states, os.path.join(cfg.save_path, 'model-%d.ckpt' % iters))
                    print('save model-%d' % iters)
def _predict_tiled_labels(model, raw, cfg, pad, device):
    """Predict a full image tile by tile and return the stitched label map.

    The image is split by ``Crop_image`` using ``cfg.TRAIN.crop_size`` and
    ``cfg.TRAIN.overlap``. Every tile is padded by ``pad`` pixels on each
    side and run through ``model`` without gradients. The third model output
    (index 2) is cropped back by ``pad``, reduced with softmax + argmax over
    dim 1 and handed to ``crop_img.save``. Returns ``crop_img.result()``.
    """
    crop_img = Crop_image(raw, crop_size=cfg.TRAIN.crop_size, overlap=cfg.TRAIN.overlap)
    for i in range(crop_img.num):
        for j in range(crop_img.num):
            raw_crop = crop_img.gen(i, j)
            # Add the batch axis (and a channel axis when the tile has none).
            if crop_img.dim == 3:
                raw_crop_ = raw_crop[np.newaxis, :, :, :]
            else:
                raw_crop_ = raw_crop[np.newaxis, np.newaxis, :, :]
            inputs = torch.Tensor(raw_crop_).to(device)
            inputs = F.pad(inputs, (pad, pad, pad, pad))
            with torch.no_grad():
                pred = model(inputs)
            pred = pred[2]
            # Negative padding crops the border that was added above.
            pred = F.pad(pred, (-pad, -pad, -pad, -pad))
            pred = F.softmax(pred, dim=1)
            pred = torch.argmax(pred, dim=1).squeeze(0)
            crop_img.save(i, j, pred.data.cpu().numpy())
    return crop_img.result()


def loop(cfg, train_provider, valid_provider, model, criterion, optimizer, iters, writer):
    """Train a multi-output ``model`` with an added VGG feature loss.

    The total loss is ``0.75 * (0.5 * loss1 + loss2) + loss3`` (from
    ``multiloss``) plus ``cfg.TRAIN.vgg_weight`` times ``vgg_loss`` computed
    on ``VGG19`` features of channel 1 of the third model output and of the
    target, each repeated to three channels.

    Args:
        cfg: Configuration object. Reads ``cfg.record_path``,
            ``cfg.cache_path``, ``cfg.save_path`` and the ``cfg.TRAIN`` fields
            ``pad``, ``thresd``, ``vgg_weight``, ``total_iters``, ``base_lr``,
            ``end_lr``, ``weight_decay``, ``display_freq``, ``valid_freq``,
            ``save_freq``, ``crop_size`` and ``overlap``.
        train_provider: Object whose ``next()`` returns ``(input, target)``.
        valid_provider: Object exposing ``data.num`` and ``data.gen(k)``
            returning ``(raw, label)``.
        model: Network to train; ``model(input)`` must return an indexable
            collection with at least three outputs.
        criterion: Unused here; kept for interface compatibility.
        optimizer: Optimizer whose ``param_groups`` are updated in place.
        iters: Iteration count to resume from.
        writer: Summary writer providing ``add_scalar(tag, value, step)``.

    Returns:
        None. Results are written to ``loss.txt`` / ``valid.txt``
        (truncated on start), to preview images in ``cfg.cache_path`` and to
        ``model.ckpt`` in ``cfg.save_path``.

    NOTE(review): the counter is incremented at the top of the body while the
    condition is ``iters <= total_iters``, so the last executed step is
    ``total_iters + 1`` -- confirm this is intended.
    """
    PAD = cfg.TRAIN.pad
    thresd = cfg.TRAIN.thresd
    vgg_weight = cfg.TRAIN.vgg_weight
    rcd_time = []
    sum_time = 0
    sum_loss = 0
    most_f1 = 0
    most_iters = 0
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Feature extractor for the VGG loss term.
    model_vgg = VGG19().to(device)
    cuda_count = torch.cuda.device_count()
    if cuda_count > 1:
        model_vgg = nn.DataParallel(model_vgg)
        print('VGG build on %d GPUs ... ' % cuda_count, flush=True)
    else:
        print('VGG build on a single GPU ... ', flush=True)

    # The context manager guarantees both log files are closed, even when
    # training is interrupted by an exception.
    with open(os.path.join(cfg.record_path, 'loss.txt'), 'w') as f_loss_txt, \
            open(os.path.join(cfg.record_path, 'valid.txt'), 'w') as f_valid_txt:
        while iters <= cfg.TRAIN.total_iters:
            # train
            iters += 1
            t1 = time.time()
            batch, target = train_provider.next()

            # decay learning rate
            if cfg.TRAIN.end_lr == cfg.TRAIN.base_lr:
                current_lr = cfg.TRAIN.base_lr
            else:
                current_lr = calculate_lr(iters)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = current_lr

            optimizer.zero_grad()
            batch = F.pad(batch, (PAD, PAD, PAD, PAD))
            # Validation below switches to eval mode; switch back every step.
            model.train()
            pred = model(batch)
            # Negative padding crops every output back to the target size.
            pred_depad = [F.pad(p, (-PAD, -PAD, -PAD, -PAD)) for p in pred]
            del pred

            # ---- compute loss ----
            loss1, loss2, loss3 = multiloss(pred_depad, target)
            # Channel 1 of the last output, repeated to 3 channels for VGG.
            pred_depad_2 = torch.unsqueeze(pred_depad[2][:, 1], dim=1)
            pred_ = torch.cat([pred_depad_2, pred_depad_2, pred_depad_2], dim=1)
            target_ = torch.cat([target, target, target], dim=1)
            loss_twonet = 0.75*(0.5*loss1+loss2)+loss3
            out1 = model_vgg(pred_)
            out2 = model_vgg(target_)
            loss_vgg2 = vgg_loss(out1, out2)
            loss = loss_twonet + vgg_weight * loss_vgg2
            loss.backward()
            if cfg.TRAIN.weight_decay is not None:
                # Decoupled weight decay: p <- p - wd * lr * p. Uses the
                # keyword ``alpha`` form; the positional (Number, Tensor)
                # overload is deprecated.
                for group in optimizer.param_groups:
                    decay = -cfg.TRAIN.weight_decay * group['lr']
                    for param in group['params']:
                        param.data = param.data.add(param.data, alpha=decay)
            optimizer.step()

            sum_loss += loss.item()
            sum_time += time.time() - t1

            # log train
            if iters % cfg.TRAIN.display_freq == 0:
                rcd_time.append(sum_time)
                shown_loss = sum_loss / cfg.TRAIN.display_freq * 10
                remaining_min = ((cfg.TRAIN.total_iters - iters) / cfg.TRAIN.display_freq
                                 * np.mean(np.asarray(rcd_time)) / 60)
                logging.info('step %d, loss = %.4f (wt: *10, lr: %.8f, et: %.2f sec, rd: %.2f min)'
                             % (iters, shown_loss, current_lr, sum_time, remaining_min))
                logging.info('step %d, loss_twonet = %.4f, loss_vgg = %.4f' % (iters, loss_twonet, vgg_weight * loss_vgg2))
                writer.add_scalar('loss/loss1', loss1, iters)
                writer.add_scalar('loss/loss2', loss2, iters)
                writer.add_scalar('loss/loss3', loss3, iters)
                writer.add_scalar('loss/loss_vgg', loss_vgg2, iters)
                writer.add_scalar('final_loss', shown_loss, iters)
                f_loss_txt.write('step = ' + str(iters) + ', loss = ' + str(shown_loss))
                f_loss_txt.write('\n')
                f_loss_txt.flush()
                sys.stdout.flush()
                sum_time = 0
                sum_loss = 0

            # valid: dump a 2 x 3 preview grid of the first training sample
            if iters % cfg.TRAIN.valid_freq == 0:
                batch = F.pad(batch, (-PAD, -PAD, -PAD, -PAD))
                input0 = (np.squeeze(batch[0].data.cpu().numpy()) * 255).astype(np.uint8)
                target_img = (np.squeeze(target[0].data.cpu().numpy()) * 255).astype(np.uint8)
                pred_show = []
                for p in pred_depad:
                    # np.clip returns a new array, so the tensor's storage
                    # (which the NumPy view may share on CPU) is not modified.
                    temp = np.clip(np.squeeze(p[0].data.cpu().numpy()), 0, 1)
                    pred_show.append((temp * 255).astype(np.uint8))
                # Filler tile for the unused grid cell.
                # NOTE(review): filled with 1, not 255, despite the name.
                white = np.ones_like(pred_show[2][1], dtype=np.uint8)
                im1 = np.concatenate([input0, pred_show[0][1], pred_show[1][1]], axis=1)
                im2 = np.concatenate([target_img, pred_show[2][1], white], axis=1)
                im_cat = np.concatenate([im1, im2], axis=0)
                Image.fromarray(im_cat).save(os.path.join(cfg.cache_path, '%06d.png' % iters))

            # save
            if iters % cfg.TRAIN.save_freq == 0:
                model.eval()
                # Scores are collected per validation round so the reported F1
                # (and the best-checkpoint decision) reflects the current
                # weights only, not an average over past rounds.
                valid_score = []
                valid_score_tmp = []
                for k in range(valid_provider.data.num):
                    raw, label = valid_provider.data.gen(k)
                    results = _predict_tiled_labels(model, raw, cfg, PAD, device)
                    results[results<=thresd] = 0
                    results[results>thresd] = 1
                    temp_label = label.flatten()
                    temp_result = results.flatten()
                    # F1 is computed on the inverted masks.
                    f1_common = f1_score(1 - temp_label, 1 - temp_result)
                    f1_soft = 0  # placeholder: no soft F1 is computed here
                    results_img = (results * 255).astype(np.uint8)
                    label_img = (label * 255).astype(np.uint8)
                    im_cat_valid = np.concatenate([results_img, label_img], axis=1)
                    if k == valid_provider.data.num - 1:
                        Image.fromarray(im_cat_valid).save(
                            os.path.join(cfg.cache_path, 'valid_%06d.png' % iters))
                    valid_score.append(f1_soft)
                    valid_score_tmp.append(f1_common)
                avg_f1 = sum(valid_score) / len(valid_score)
                avg_f1_tmp = sum(valid_score_tmp) / len(valid_score_tmp)
                logging.info('step %d, f1_soft = %.6f' % (iters, avg_f1))
                logging.info('step %d, f1_common = %.6f' % (iters, avg_f1_tmp))
                writer.add_scalar('valid', avg_f1_tmp, iters)
                f_valid_txt.write('step %d, f1 = %.6f' % (iters, avg_f1_tmp))
                f_valid_txt.write('\n')
                f_valid_txt.flush()
                sys.stdout.flush()
                # Keep only the best-scoring weights.
                if avg_f1_tmp > most_f1:
                    most_f1 = avg_f1_tmp
                    most_iters = iters
                    states = {'current_iter': iters, 'valid_result': avg_f1_tmp,
                              'model_weights': model.state_dict()}
                    torch.save(states, os.path.join(cfg.save_path, 'model.ckpt'))
                    print('***************save modol, when f1 = %.6f and iters = %d.***************' % (most_f1, most_iters))