import tensorflow as tf from parse_opt import get_arguments from deeplab_resnet import HyperColumn_Deeplabv2, read_data_list import os import numpy as np import torch import torch.utils.data as data import main_hyper from threading import Lock shape = 250 args = get_arguments() config = tf.ConfigProto(device_count={'GPU': 0}) sess = tf.Session(config=config) model = HyperColumn_Deeplabv2(sess, args) model.load(args.snapshot_dir) def np2Tensor(array): ts = (2, 0, 1) tmp = array.copy() tensor = torch.FloatTensor(tmp.transpose(ts).astype(float)) return tensor class InputData(data.Dataset): def __init__(self, data_folder, padsize=50): self.local_imgflist = main_hyper.load_dir_structs(data_folder) self.padsize = padsize # config.gpu_options.allow_growth = True
def main(): # load train args args = get_arguments() if args.without_gpu: device = torch.device('cpu') else: if torch.cuda.is_available(): device = torch.device('cuda') else: print('No gpu available') return model = torch_net.FCN() model.to(device) train_data = DataLoader(dataset.InputData(data_folder=args.dataDir), batch_size=args.train_batch, drop_last=True, shuffle=True, num_workers=0, pin_memory=True) # 只能将num_workers设为0, 表示只用主线程来load valid_data = DataLoader(dataset.InputData(data_folder=args.validDir), batch_size=1, drop_last=False, shuffle=True, num_workers=0) model.train() lr = args.lr trainlog = TrainLog(args) start_epoch = 1 if args.finetuning: start_epoch, model = trainlog.load_model(model) optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, betas=(0.9, 0.999), weight_decay=0.0005) criterion = nn.L1Loss() iou_threshold = [0.2, 0.4, 0.6, 0.8, 0.9] for epoch in range(start_epoch, start_epoch + num_epochs): loss_ = 0. iou_ = [0.] * len(iou_threshold) for i, batch in enumerate(train_data): feat, img, alpha_gt = batch feat, img, alpha_gt = feat.to(device), img.to(device), alpha_gt.to( device) alpha = model(feat, img) loss = criterion(alpha, alpha_gt) iou = [IOU(alpha_gt, alpha, t) for t in iou_threshold] optimizer.zero_grad() loss.backward() optimizer.step() loss_ += loss.item() for j in range(len(iou_threshold)): iou_[j] += iou[j] print('loss:', loss_ / (1 + i)) trainlog.add_scalar('L1', loss.item()) for j in range(len(iou_threshold)): trainlog.add_scalar('IOU_' + str(iou_threshold[j]), iou[j]) trainlog.add_image( 'alpha_gt', vutils.make_grid(alpha_gt, normalize=False, nrow=3)) trainlog.add_image( 'alpha', vutils.make_grid(alpha, normalize=False, nrow=3)) trainlog.add_image( 'alpha_diff', vutils.make_grid(alpha - alpha_gt, normalize=False, nrow=3)) if (i + 1) % 20 == 0: for var_name, value in model.named_parameters(): # 万一网络中有不在本次训练iterate的参数 if not hasattr(value.grad, 'data'): continue var_name = var_name.replace('.', '/') trainlog.add_histogram(var_name, value.data.cpu().numpy()) trainlog.add_histogram(var_name + '/grad', value.grad.data.cpu().numpy()) trainlog.next_step() print('Epoch %d, avg loss: %.3f' % (epoch, loss_ / (i + 1))) trainlog.save_model(model, epoch) for j in range(len(iou_threshold)): trainlog.add_scalar('avg_IOU_' + str(iou_threshold[j]), iou_[j] / (i + 1), epoch) model.eval() loss_valid_ = 0. iou_valid_2, iou_valid_4, iou_valid_6, iou_valid_8, iou_valid_9 = 0., 0., 0., 0., 0. for j, v_data in enumerate(valid_data): feat, img, alpha_gt = v_data feat, img, alpha_gt = feat.to(device), img.to(device), alpha_gt.to( device) alpha = model(feat, img) loss = criterion(alpha, alpha_gt) iou_2 = IOU(alpha_gt, alpha, 0.2) iou_4 = IOU(alpha_gt, alpha, 0.4) iou_6 = IOU(alpha_gt, alpha, 0.6) iou_8 = IOU(alpha_gt, alpha, 0.8) iou_9 = IOU(alpha_gt, alpha, 0.9) loss_valid_ += loss.item() iou_valid_2 += iou_2 iou_valid_4 += iou_4 iou_valid_6 += iou_6 iou_valid_8 += iou_8 iou_valid_9 += iou_9 print('loss: %.3f' % (loss_valid_ / (j + 1))) trainlog.add_scalar('loss_val', loss_valid_ / (j + 1), epoch) trainlog.add_scalar('iou_val.2', iou_valid_2 / (j + 1), epoch) trainlog.add_scalar('iou_val.4', iou_valid_4 / (j + 1), epoch) trainlog.add_scalar('iou_val.6', iou_valid_6 / (j + 1), epoch) trainlog.add_scalar('iou_val.8', iou_valid_8 / (j + 1), epoch) trainlog.add_scalar('iou_val.9', iou_valid_9 / (j + 1), epoch) print('Epoch %d validation, avg_loss: %.3f' % (epoch, loss_valid_ / (j + 1))) model.train()