def main():
    """Train the segmentation network against a fully-convolutional discriminator.

    Reads every setting from the module-level ``args`` namespace.  Each outer
    iteration accumulates gradients over ``args.iter_size`` sub-iterations:

    1. (optional, once ``i_iter >= args.semi_start_adv``) a semi-supervised
       step on the "remain" loader: adversarial loss and, once
       ``i_iter >= args.semi_start``, a self-training loss on pseudo-labels;
    2. a supervised segmentation step plus adversarial loss on labelled data;
    3. a discriminator step on detached predictions and on one-hot ground truth.

    Snapshots of both networks are written to ``args.snapshot_dir``.
    """

    def next_batch(batch_iter, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` when it runs out."""
        # The builtin next() works on enumerate objects in Python 2 and 3; the
        # original ``.next()`` method only exists in Python 2.
        try:
            _, batch = next(batch_iter)
        except StopIteration:
            batch_iter = enumerate(loader)
            _, batch = next(batch_iter)
        return batch, batch_iter

    def scalar(loss):
        """Return a loss tensor as a plain number.

        ``ndarray.item()`` accepts both 0-dim arrays (torch >= 0.4) and
        one-element arrays (older torch), unlike ``numpy()[0]``.
        """
        return loss.data.cpu().numpy().item()

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in the current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if name in saved_state_dict and param.size() == saved_state_dict[name].size():
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D = nn.DataParallel(model_D)
    model_D.train()
    model_D.cuda()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir, args.data_list, crop_size=input_size,
                               scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir, args.data_list, crop_size=input_size,
                                    scale=args.random_scale, mirror=args.random_mirror, mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, shuffle=True, num_workers=5, pin_memory=True)
        # NOTE(review): no "remain" loader exists on this path, so the
        # semi-supervised branch below cannot run when partial_data is None.
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickles are binary: open in 'rb' and close the handle.
            # NOTE(review): pickle.load executes arbitrary code; only use
            # id files from a trusted source.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'), 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size, sampler=train_sampler,
                                      num_workers=3, pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size, sampler=train_remain_sampler,
                                             num_workers=3, pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size, sampler=train_gt_sampler,
                                         num_workers=3, pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.module.optim_parameters(args),
                          lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling (align_corners only exists from torch 0.4.0)
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    # File-name stem shared by every snapshot: 'VOC_<script name>_'.
    snapshot_prefix = 'VOC_'+os.path.abspath(__file__).split('/')[-1].split('.')[0]+'_'

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                batch, trainloader_remain_iter = next_batch(trainloader_remain_iter, trainloader_remain)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda()

                pred = interp(model(images))
                pred_remain = pred.detach()

                pred_prob = F.softmax(pred, dim=1)
                D_out = interp(model_D(pred_prob))
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # np.bool was removed from NumPy; the builtin bool is equivalent.
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)

                semi_adv_bce = bce_loss(D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = args.lambda_semi_adv * semi_adv_bce
                loss_semi_adv = loss_semi_adv/args.iter_size

                # Log the unweighted loss directly instead of dividing the
                # weighted one by lambda_semi_adv, which may be 0 here.
                loss_semi_adv_value += scalar(semi_adv_bce)/args.iter_size

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask: keep only pixels whose top softmax
                    # probability is (almost) 1.  The original also computed
                    # ``D_out_sigmoid < args.mask_T`` but immediately
                    # overwrote it, so only this mask is effective.
                    # float64 matches the original comparison precision.
                    max_prob = pred_prob.data.cpu().numpy().max(axis=1).astype(np.float64)
                    semi_ignore_mask = (max_prob < 0.999999)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum())/semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt)
                        loss_semi = loss_semi/args.iter_size
                        # lambda_semi > 0 is guaranteed on this branch.
                        loss_semi_value += scalar(loss_semi)/args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()
            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source
            batch, trainloader_iter = next_batch(trainloader_iter, trainloader)

            images, labels, _, _ = batch
            images = Variable(images).cuda()
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels)

            D_out = interp(model_D(F.softmax(pred, dim=1)))

            loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss/args.iter_size
            loss.backward()
            loss_seg_value += scalar(loss_seg)/args.iter_size
            loss_adv_pred_value += scalar(loss_adv_pred)/args.iter_size

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            if args.D_remain:
                # NOTE(review): pred_remain / ignore_mask_remain only exist
                # after the semi-supervised branch above has run.
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate((ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += scalar(loss_D)

            # train with gt
            # get gt labels
            batch, trainloader_gt_iter = next_batch(trainloader_gt_iter, trainloader_gt)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D/args.iter_size/2
            loss_D.backward()
            loss_D_value += scalar(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(i_iter, args.num_steps, loss_seg_value, loss_adv_pred_value, loss_D_value, loss_semi_value, loss_semi_adv_value))

        if i_iter >= args.num_steps-1:
            print('save model ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix+str(args.num_steps)+'.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix+str(args.num_steps)+'_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix+str(i_iter)+'.pth'))
            torch.save(model_D.state_dict(), osp.join(args.snapshot_dir, snapshot_prefix+str(i_iter)+'_D.pth'))

    # ``start`` is not set in this function; presumably a module-level timer.
    end = timeit.default_timer()
    print(end-start, 'seconds')
def main():
    """Train the segmentation network on LiverDataset slices (no discriminator).

    Reads optimisation settings from the module-level ``args`` namespace and
    pulls batches from ``LiverDataset.get_next_batch()``.  Each outer
    iteration accumulates gradients over ``args.iter_size`` batches, takes one
    SGD step, and prints the segmentation loss together with a foreground
    Dice score (any label >= 1 counts as foreground) accumulated over those
    batches.  Snapshots are written to ``args.snapshot_dir``.

    The adversarial / semi-supervised branches and the TensorFlow summary
    logging that the original script kept as disabled code are not part of
    this variant.
    """
    # LD ADD start
    from dataset.LiverDataset.liver_dataset import LiverDataset
    user_name = 'give'
    validation_interval = 800
    batch_size = 1
    n_neighboringslices = 5
    input_size = 400
    slice_type = 'axial'
    oversample = False
    label_of_interest = 1
    label_required = 0
    max_slice_tries_val = 0
    max_slice_tries_train = 2
    fuse_labels = True
    apply_crop = False

    train_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_2"
    test_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1"
    train_dataset = LiverDataset(data_dir=train_data_dir, slice_type=slice_type,
                                 n_neighboringslices=n_neighboringslices, input_size=input_size,
                                 oversample=oversample, label_of_interest=label_of_interest,
                                 label_required=label_required, max_slice_tries=max_slice_tries_train,
                                 fuse_labels=fuse_labels, apply_crop=apply_crop,
                                 interval=validation_interval, is_training=True,
                                 batch_size=batch_size, data_augmentation=False)
    # Not consumed below (validation is disabled); still constructed as in
    # the original in case the constructor has side effects.
    val_dataset = LiverDataset(data_dir=test_data_dir, slice_type=slice_type,
                               n_neighboringslices=n_neighboringslices, input_size=input_size,
                               oversample=oversample, label_of_interest=label_of_interest,
                               label_required=label_required, max_slice_tries=max_slice_tries_val,
                               fuse_labels=fuse_labels, apply_crop=apply_crop,
                               interval=validation_interval, is_training=False,
                               batch_size=batch_size)
    # LD ADD end

    perfix_name = 'Liver'

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes, input_channel=1,
                        slice_num=n_neighboringslices, gpu_id=args.gpu)
    if RESTORE_FLAG:
        # load pretrained parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            print(name)
            if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
                print('copy {}'.format(name))
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args), lr=args.learning_rate,
                          momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # bilinear upsampling to the network input resolution
    # (align_corners only exists from torch 0.4.0)
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size, input_size), mode='bilinear', align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size, input_size), mode='bilinear')

    # ``iter_start`` is not set in this function; presumably module-level.
    for i_iter in range(iter_start, args.num_steps):
        loss_seg_value = 0
        num_prediction = 0
        num_ground_truth = 0
        num_intersection = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # train with source
            batch_image, batch_label = train_dataset.get_next_batch()
            # Move the last axis to position 1 (channels-first for the model).
            batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
            images = Variable(torch.Tensor(batch_image)).cuda(args.gpu)

            pred = interp(model(images))

            # Per-pixel argmax over the class axis, as a numpy label map.
            pred_ny = np.transpose(pred.data.cpu().numpy(), axes=(0, 2, 3, 1))
            pred_label_ny = np.squeeze(np.argmax(pred_ny, axis=3))

            # prepare for dice: count foreground pixels.  The prediction term
            # uses the same ``>= 1`` test as the other two terms so the score
            # stays a valid Dice if there are more than two classes (the
            # original summed raw label values, identical for 0/1 labels).
            num_prediction += np.sum(pred_label_ny >= 1)
            num_ground_truth += np.sum(np.asarray(batch_label >= 1, np.uint8))
            num_intersection += np.sum(
                np.asarray(np.logical_and(batch_label >= 1, pred_label_ny >= 1), np.uint8))

            loss_seg = loss_calc(pred, batch_label, args.gpu)
            loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()

            # .item() yields a plain float for both 0-dim and 1-element
            # results, so the '{:.3f}' formatting below always works.
            loss_seg_value += loss_seg.data.cpu().numpy().item() / args.iter_size

        optimizer.step()

        # The epsilon keeps the score defined when both counts are zero.
        dice = (2 * num_intersection + 1e-7) / (num_prediction + num_ground_truth + 1e-7)
        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value))
        print(
            'dice: %.4f, num_prediction: %d, num_ground_truth: %d, num_intersection: %d'
            % (dice, num_prediction, num_ground_truth, num_intersection))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, perfix_name + str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            save_model(model, args.snapshot_dir, perfix_name, i_iter, 2)

    # The original closed ``training_summary`` / ``val_summary`` here, but
    # those writers are never created in this function (their setup is
    # disabled), so the calls raised NameError once training finished.

    # ``start`` is not set in this function; presumably a module-level timer.
    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train a segmentation network with an adversarial, semi-supervised loss.

    Reads every setting from the module-level ``args`` namespace.  The
    labelled ``BasicDataset`` is split into train/validation parts and an
    ``UnlabeledDataset`` provides the unlabelled images.  For each iteration:

    1. (optional, once ``i_iter >= args.semi_start_adv``) a step on unlabelled
       images: adversarial loss and, once ``i_iter >= args.semi_start``, a
       self-training loss on discriminator-filtered pseudo-labels;
    2. a supervised cross-entropy step on labelled images, plus adversarial
       loss once ``i_iter > args.semi_start_adv``;
    3. every third iteration after ``args.semi_start_adv``, a discriminator
       step on detached predictions and on one-hot ground truth, with
       noisy targets in [0, 0.1) for fake and [0.9, 1.0) for real.

    Every 1000 iterations the network is validated with ``eval_net`` and,
    if ``args.save_cp`` is set, checkpointed under ``DIR_CHECKPOINTS``.

    Raises:
        ValueError: if ``args.mod`` does not name a supported network.
    """

    def next_batch(batch_iter, loader):
        """Return ``(batch, iterator)``, restarting ``loader`` when it runs out."""
        # Only exhaustion should trigger a restart; a bare ``except`` would
        # also hide real data-loading errors.
        try:
            _, batch = next(batch_iter)
        except StopIteration:
            batch_iter = enumerate(loader)
            _, batch = next(batch_iter)
        return batch, batch_iter

    def scalar(loss):
        """Return a loss tensor as a Python float."""
        return loss.cpu().detach().numpy().item()

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize parameters
    num_steps = args.num_steps
    batch_size = args.batch_size
    lr = args.lr
    save_cp = args.save_cp
    img_scale = args.scale
    val_percent = args.val / 100

    # data input
    dataset = BasicDataset(IMG_DIRECTORY, MASK_DIRECTORY, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    tcga_dataset = UnlabeledDataset(TCGA_DIRECTORY)
    n_unlabeled = len(tcga_dataset)

    # create network
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logging.info('Using device %s' % str(device))
    logging.info('Network %s' % args.mod)
    # The learning-rate line originally ended in a stray '_transform'.
    logging.info('''Starting training:
        Num_steps:       %.2f
        Batch size:      %.2f
        Learning rate:   %.4f
        Training size:   %.0f
        Validation size: %.0f
        Unlabeled size:  %.0f
        Checkpoints:     %s
        Device:          %s
        Scale:           %.2f
    ''' % (num_steps, batch_size, lr, n_train, n_val, n_unlabeled,
           str(save_cp), str(device.type), img_scale))

    if args.mod == 'unet':
        net = UNet(n_channels=3, n_classes=NUM_CLASSES)
        print('channels = %d , classes = %d' % (net.n_channels, net.n_classes))
    elif args.mod == 'modified_unet':
        net = modified_UNet(n_channels=3, n_classes=NUM_CLASSES)
        print('channels = %d , classes = %d' % (net.n_channels, net.n_classes))
    elif args.mod == 'deeplabv3':
        net = DeepLabV3(nclass=NUM_CLASSES, pretrained_base=False)
        print('channels = 3 , classes = %d' % net.nclass)
    elif args.mod == 'deeplabv3plus':
        net = DeepLabV3Plus(nclass=NUM_CLASSES, pretrained_base=False)
        print('channels = 3 , classes = %d' % net.nclass)
    elif args.mod == 'nestedunet':
        net = NestedUNet(nclass=NUM_CLASSES, deep_supervision=False)
        # The original read the misspelled attribute ``net.nlass``; print the
        # value passed to the constructor instead.
        print('channels = 3 , classes = %d' % NUM_CLASSES)
    elif args.mod == 'inception3':
        net = Inception3(n_classes=4, inception_blocks=None,
                         init_weights=True, bilinear=True)
        print('channels = 3 , classes = %d' % net.n_classes)
    else:
        # Fail early instead of hitting a NameError on ``net`` below.
        raise ValueError('Unknown network type: %s' % args.mod)

    net.to(device=device)
    net.train()

    cudnn.benchmark = True

    # init D (placed on the same device as the segmentation network)
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.to(device=device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    if args.semi_train is None:
        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                                  num_workers=8, pin_memory=True)
        val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                                num_workers=8, pin_memory=True)
        # NOTE(review): no unlabelled loader exists on this path, so the
        # semi-supervised branch below cannot run when semi_train is None.
    else:
        # read unlabeled data and labeled data
        train_loader = DataLoader(train, batch_size=batch_size, shuffle=True,
                                  num_workers=4, pin_memory=True)
        val_loader = DataLoader(val, batch_size=batch_size, shuffle=False,
                                num_workers=4, pin_memory=True)
        trainloader_remain = DataLoader(tcga_dataset, batch_size=batch_size, shuffle=True,
                                        num_workers=4, pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(train_loader)

    # optimizer for segmentation network
    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    optimizer.zero_grad()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 10000,
                                                           eta_min=1e-6, last_epoch=-1)

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    criterion = nn.CrossEntropyLoss()  # built once, not on every step
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear', align_corners=True)

    # Best validation accuracy so far.  This must live outside the loop: the
    # original reset it every iteration, so every validation looked "best".
    best_acc = 0

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False
            for param in net.parameters():
                param.requires_grad = True

            # do semi first
            if (args.lambda_semi > 0 or args.lambda_semi_adv > 0) and i_iter >= args.semi_start_adv:
                batch, trainloader_remain_iter = next_batch(trainloader_remain_iter,
                                                            trainloader_remain)

                # only access to img
                images = batch['image'].to(device=device, dtype=torch.float32)

                pred = net(images)
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy().squeeze(axis=1)

                # The generator wants D to call its predictions "real" (1).
                targetr = torch.ones_like(D_out)

                semi_adv_bce = bce_loss(D_out, targetr)
                loss_semi_adv = args.lambda_semi_adv * semi_adv_bce
                loss_semi_adv = loss_semi_adv / args.iter_size

                # Log the unweighted loss directly instead of dividing the
                # weighted one by lambda_semi_adv, which may be 0 here.
                loss_semi_adv_value += scalar(semi_adv_bce) / args.iter_size

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask: drop pixels the discriminator is
                    # not confident about
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(pred, semi_gt)
                        loss_semi = loss_semi / args.iter_size
                        # lambda_semi > 0 is guaranteed on this branch.
                        loss_semi_value += scalar(loss_semi) / args.lambda_semi
                        loss_semi += loss_semi_adv
                        loss_semi.backward()
            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source
            batch, trainloader_iter = next_batch(trainloader_iter, train_loader)

            images = batch['image'].to(device=device, dtype=torch.float32)
            labels = batch['mask'].to(device=device, dtype=torch.long)
            labels = labels.squeeze(1)

            pred = net(images)
            loss_seg = criterion(pred, labels)

            if i_iter > args.semi_start_adv:
                D_out = interp(model_D(F.softmax(pred, dim=1)))
                loss_adv_pred = bce_loss(D_out, torch.ones_like(D_out))
                loss = loss_seg + args.lambda_adv_pred * loss_adv_pred
                loss_adv_pred_value += scalar(loss_adv_pred) / args.iter_size
            else:
                loss = loss_seg

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            optimizer.step()
            loss_seg_value += scalar(loss_seg) / args.iter_size

            # train D
            if i_iter > args.semi_start_adv and i_iter % 3 == 0:
                # bring back requires_grad
                for param in net.parameters():
                    param.requires_grad = False
                for param in model_D.parameters():
                    param.requires_grad = True

                # train with pred
                pred = pred.detach()

                if args.D_remain:
                    # NOTE(review): pred_remain only exists after the
                    # semi-supervised branch above has run.
                    pred = torch.cat((pred, pred_remain), 0)

                D_out = interp(model_D(F.softmax(pred, dim=1)))
                # Noisy "fake" targets in [0, 0.1).
                targetf = torch.FloatTensor(0.1 * np.random.rand(*D_out.shape)).to(device)
                loss_D = bce_loss(D_out, targetf)
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                loss_D_value += scalar(loss_D)

                # train with gt
                # get gt labels
                batch, trainloader_iter = next_batch(trainloader_iter, train_loader)

                labels_gt = batch['mask']
                D_gt_v = one_hot(labels_gt).to(device)

                D_out = interp(model_D(D_gt_v))
                # Noisy "real" targets in [0.9, 1.0).
                targetr = torch.FloatTensor(0.1 * np.random.rand(*D_out.shape) + 0.9).to(device)
                loss_D = bce_loss(D_out, targetr)
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                optimizer_D.step()
                loss_D_value += scalar(loss_D)

        scheduler.step()

        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        # save checkpoints
        if save_cp and (i_iter % 1000) == 0 and (i_iter != 0):
            try:
                os.mkdir(DIR_CHECKPOINTS)
                logging.info('Created checkpoint directory')
            except OSError:
                # Best effort: the directory usually exists already.
                pass
            torch.save(net.state_dict(),
                       DIR_CHECKPOINTS + 'i_iter_%d.pth' % (i_iter + 1))
            logging.info('Checkpoint %d saved !' % (i_iter + 1))

        if (i_iter % 1000 == 0) and (i_iter != 0):
            val_score, accuracy, dice_avr, dice_panck, dice_nuclei, dice_lcell = eval_net(
                net, val_loader, device, n_val)
            # NOTE(review): eval_net presumably switches the network to eval
            # mode; restoring train mode is harmless if it does not.
            net.train()
            logging.info('Validation cross entropy: {}'.format(val_score))
            if accuracy > best_acc:
                best_acc = accuracy
                # ``with`` closes the file; the original wrote
                # ``result_file.close`` without calling it.
                with open('result.txt', 'a', encoding='utf-8') as result_file:
                    result_file.write('best_acc = ' + str(best_acc) + '\n' +
                                      'iter = ' + str(i_iter) + '\n')
def main():
    """Train the segmentation net with a discriminator feature-matching loss.

    Everything is configured through the module-level ``args`` namespace.
    The function:

    1. builds ``Res_Deeplab`` and copies every pretrained parameter whose
       name and shape match the checkpoint given by ``args.restore_from``;
    2. builds ``FCDiscriminator`` (optionally restored from
       ``args.restore_from_D``);
    3. builds the data loaders: a labelled loader, a ground-truth loader and
       an "all images" loader used for the feature-matching term;
    4. runs ``args.num_steps`` iterations.  Each iteration accumulates
       gradients over ``args.iter_size`` sub-steps (supervised segmentation
       loss, feature-matching loss once ``i_iter >= args.adv_start``, and the
       discriminator real/fake loss) and then steps both optimizers;
    5. writes snapshots into ``args.snapshot_dir`` every
       ``args.save_pred_every`` iterations and once more at the end.

    Relies on module-level names that are not defined here: ``args``,
    ``criterion``, ``loss_calc``, ``one_hot``, ``adjust_learning_rate``,
    ``adjust_learning_rate_D``, ``IMG_MEAN`` and ``start``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    def _scalar(loss):
        # Works for both 1-element tensors (old PyTorch) and 0-dim tensors
        # (PyTorch >= 0.4), where ``.numpy()[0]`` raises IndexError.
        return loss.data.cpu().numpy().reshape(-1)[0]

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=16,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=16,
                                         pin_memory=True)
        # Without a labelled/unlabelled split the feature-matching batches
        # come from the same full training set.  Previously this name was
        # only defined in the partial-data branch and the loop crashed.
        trainloader_all = trainloader
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # NOTE: pickle executes arbitrary code on load; only point
            # ``--partial-id`` at files you created yourself.
            # Binary mode is required by pickle on Python 3.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler_all = data.sampler.SubsetRandomSampler(train_ids)
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader_all = data.DataLoader(train_dataset,
                                          batch_size=args.batch_size,
                                          sampler=train_sampler_all,
                                          num_workers=16,
                                          pin_memory=True)
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=16,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=16,
                                         pin_memory=True)

    trainloader_all_iter = iter(trainloader_all)
    trainloader_iter = iter(trainloader)
    trainloader_gt_iter = iter(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # bilinear upsampling of the network output back to the crop size
    # NOTE(review): ``input_size`` is parsed as (h, w) but passed here as
    # (w, h); harmless for square crops -- confirm for non-square ones.
    interp = nn.Upsample(size=(input_size[1], input_size[0]), mode='bilinear')

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_D_value = 0
        loss_fm_value = 0
        loss_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for _ in range(args.iter_size):
            # ---- train G ----
            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source
            try:
                batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = iter(trainloader)
                batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_seg.backward()
            loss_seg_value += _scalar(loss_seg) / args.iter_size

            if i_iter >= args.adv_start:
                # feature-matching loss on a batch drawn from all images
                try:
                    batch = next(trainloader_all_iter)
                except StopIteration:
                    # The old code rebound ``trainloader_iter`` here, so the
                    # retry below hit the exhausted iterator again.
                    trainloader_all_iter = iter(trainloader_all)
                    batch = next(trainloader_all_iter)

                images, labels, _, _ = batch
                images = Variable(images).cuda(args.gpu)
                pred = interp(model(images))

                _, D_out_y_pred = model_D(F.softmax(pred, dim=1))

                # Reuse one iterator instead of building a new one (and
                # re-spawning its workers) on every step.
                try:
                    batch = next(trainloader_gt_iter)
                except StopIteration:
                    trainloader_gt_iter = iter(trainloader_gt)
                    batch = next(trainloader_gt_iter)

                _, labels_gt, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
                _, D_out_y_gt = model_D(D_gt_v)

                fm_loss = torch.mean(
                    torch.abs(
                        torch.mean(D_out_y_gt, 0) -
                        torch.mean(D_out_y_pred, 0)))

                # ``loss`` is only reported; the segmentation term was
                # already back-propagated above.
                # NOTE(review): the gradient of ``fm_loss`` is not scaled by
                # ``args.lambda_fm`` -- confirm this is intended.
                loss = loss_seg + args.lambda_fm * fm_loss
                fm_loss.backward()

                loss_fm_value += _scalar(fm_loss) / args.iter_size
                loss_value += _scalar(loss) / args.iter_size

            # ---- train D ----
            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()
            D_out_z, _ = model_D(F.softmax(pred, dim=1))
            # Targets must live on the same device as the models.
            y_fake_ = Variable(torch.zeros(D_out_z.size(0), 1)).cuda(args.gpu)
            loss_D_fake = criterion(D_out_z, y_fake_)

            # train with gt: labels of the most recently fetched batch
            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            D_out_z_gt, _ = model_D(D_gt_v)
            y_real_ = Variable(torch.ones(D_out_z_gt.size(0),
                                          1)).cuda(args.gpu)
            loss_D_real = criterion(D_out_z_gt, y_real_)

            loss_D = loss_D_fake + loss_D_real
            loss_D.backward()
            loss_D_value += _scalar(loss_D)

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_D = {3:.3f}'.
              format(i_iter, args.num_steps, loss_seg_value, loss_D_value))
        print('fm_loss: ', loss_fm_value, ' g_loss: ', loss_value)

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Set up a liver-segmentation DeepLab model and run it on CT slices.

    The function builds ``LiverDataset`` train/val objects from hard-coded
    directories, creates ``Res_Deeplab`` (restoring matching weights when the
    module-level ``RESTORE_FROM`` is not None), creates an SGD optimizer,
    reads one hard-coded volume/segmentation pair and then loops over the
    slices of the preprocessed volume.

    NOTE(review): this looks like work in progress.  The following names are
    used but never assigned inside this function: ``half_num_slice``,
    ``image_input``, ``i_iter``, ``training_summary``, ``val_summary`` and
    ``start``; they must exist at module level for the code to run.
    ``num_prediction``, ``num_ground_truth``, ``num_intersection`` and
    ``loss_seg_value`` are targets of ``+=`` without a prior assignment here,
    which makes them function locals and raises ``UnboundLocalError`` on
    first use.

    NOTE(review): the source layout was flattened; the nesting of the slice
    loop below was reconstructed and should be confirmed.
    """
    # LD ADD start
    from dataset.LiverDataset.liver_dataset import LiverDataset
    user_name = 'give'
    validation_interval = 800
    max_steps = 1000000000
    batch_size = 1
    n_neighboringslices = 5
    input_size = 400
    output_size = 400
    slice_type = 'axial'
    oversample = False
    # reset_counter = args.reset_counter
    label_of_interest = 1
    label_required = 0
    magic_number = 26.91
    max_slice_tries_val = 0
    max_slice_tries_train = 2
    fuse_labels = True
    apply_crop = False

    # Machine-specific dataset locations derived from ``user_name``.
    train_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_2"
    test_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1"
    train_dataset = LiverDataset(data_dir=train_data_dir,
                                 slice_type=slice_type,
                                 n_neighboringslices=n_neighboringslices,
                                 input_size=input_size,
                                 oversample=oversample,
                                 label_of_interest=label_of_interest,
                                 label_required=label_required,
                                 max_slice_tries=max_slice_tries_train,
                                 fuse_labels=fuse_labels,
                                 apply_crop=apply_crop,
                                 interval=validation_interval,
                                 is_training=True,
                                 batch_size=batch_size,
                                 data_augmentation=True)
    val_dataset = LiverDataset(data_dir=test_data_dir,
                               slice_type=slice_type,
                               n_neighboringslices=n_neighboringslices,
                               input_size=input_size,
                               oversample=oversample,
                               label_of_interest=label_of_interest,
                               label_required=label_required,
                               max_slice_tries=max_slice_tries_val,
                               fuse_labels=fuse_labels,
                               apply_crop=apply_crop,
                               interval=validation_interval,
                               is_training=False,
                               batch_size=batch_size)
    # LD ADD end

    # LD build for summary (TensorBoard writers are disabled in this version)
    # training_summary = tf.summary.FileWriter(os.path.join(SUMMARY_DIR, 'train'))
    # val_summary = tf.summary.FileWriter(os.path.join(SUMMARY_DIR, 'val'))
    # dice_placeholder = tf.placeholder(tf.float32, [], name='dice')
    # loss_placeholder = tf.placeholder(tf.float32, [], name='loss')
    # # image_placeholder = tf.placeholder(tf.float32, [400*2, 400*2], name='image')
    # # prediction_placeholder = tf.placeholder(tf.float32, [400*2, 400*2], name='prediction')
    # tf.summary.scalar('dice', dice_placeholder)
    # tf.summary.scalar('loss', loss_placeholder)
    # # tf.summary.image('image', image_placeholder, max_outputs=1)
    # # tf.summary.image('prediction', prediction_placeholder, max_outputs=1)
    # summary_op = tf.summary.merge_all()
    # config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    # sess = tf.Session(config=config)

    perfix_name = 'Liver'
    # ``h`` and ``w`` are parsed but the integer ``input_size`` above is what
    # the rest of the function uses.
    h, w = map(int, args.input_size.split(','))
    # input_size = (h, w)

    cudnn.enabled = True
    gpu = args.gpu

    # create network
    model = Res_Deeplab(num_classes=args.num_classes,
                        slice_num=n_neighboringslices,
                        gpu_id=0)
    if RESTORE_FROM is not None:
        # load pretrained parameters
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            print(name)
            if name in saved_state_dict and param.size(
            ) == saved_state_dict[name].size():
                new_params[name].copy_(saved_state_dict[name])
                print('copy {}'.format(name))
        model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # LD delete (discriminator disabled; kept as an inert string literal)
    '''
    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)
    '''

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # LD delete
    '''
    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.learning_rate_D, betas=(0.9,0.99))
    optimizer_D.zero_grad()
    '''

    # loss/ bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # This first assignment is immediately overwritten by the version check
    # below, which only differs in passing ``align_corners`` on torch >= 0.4.
    interp = nn.Upsample(size=(input_size, input_size), mode='bilinear')

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size, input_size),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size, input_size), mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1
    loss_list = []

    from dataset.LiverDataset.medicalImage import preprocessing_agumentation, read_image_file
    # A single hard-coded volume and its segmentation are read here.
    image_path = '/home/give/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1/volume-0.nii'
    gt_path = '/home/give/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1/segmentation-0.nii'
    image = read_image_file(image_path)
    gt_image = read_image_file(gt_path)
    original_image = np.copy(image)
    processed_image = preprocessing_agumentation(original_image, input_size)

    # Iterate over the third axis of the preprocessed volume.
    for slice_idx in range(processed_image.shape[2]):
        # print('%d / %d ' % (slice_idx, processed_image.shape[2]))
        # Gather ``n_neighboringslices`` slices around ``slice_idx``; indices
        # that fall outside the volume are clamped to the first/last slice.
        # NOTE(review): ``half_num_slice`` and ``image_input`` are not defined
        # in this function, and ``image_input`` is never read afterwards.
        for j in range(n_neighboringslices):
            cur_idx = slice_idx - half_num_slice + j
            if cur_idx < 0:
                cur_idx = 0
            if cur_idx >= processed_image.shape[2]:
                cur_idx = processed_image.shape[2] - 1
            image_input[0, :, :, j] = processed_image[:, :, cur_idx]

        # The batch fed to the model comes from ``train_dataset``, not from
        # the slices gathered above.
        batch_image, batch_label = train_dataset.get_next_batch()
        # Move axis 3 to position 1 before handing the batch to the model
        # (presumably channels-last -> channels-first; confirm with
        # LiverDataset).
        batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
        # batch_image = np.concatenate([batch_image, batch_image, batch_image], axis=1)
        # print('Batch_images: ', np.shape(batch_image))
        batch_image_torch = torch.Tensor(batch_image)
        images = Variable(batch_image_torch).cuda(args.gpu)

        # LD delete
        # ignore_mask = (labels.numpy() == 255)
        pred = interp(model(images))
        pred_ny = pred.data.cpu().numpy()
        # Put the score axis last and take the arg-max over it as the label.
        pred_ny = np.transpose(pred_ny, axes=(0, 2, 3, 1))
        pred_label_ny = np.squeeze(np.argmax(pred_ny, axis=3))

        # prepare for dice
        # print('Shape of gt is: ', np.shape(batch_label))
        # print('Shape of pred is: ', np.shape(pred_ny))
        # print('Shape of pred_label is: ', np.shape(pred_label_ny))
        # NOTE(review): the three counters below are never initialised in
        # this function.
        num_prediction += np.sum(np.asarray(pred_label_ny, np.uint8))
        num_ground_truth += np.sum(np.asarray(batch_label >= 1, np.uint8))
        num_intersection += np.sum(
            np.asarray(np.logical_and(batch_label >= 1, pred_label_ny >= 1),
                       np.uint8))
        # num_intersection += np.sum(np.asarray(batch_label >= 1, np.uint8) == np.asarray(pred_label_ny, np.uint8))

        loss_seg = loss_calc(pred, batch_label, args.gpu)

        # LD delete
        '''
        D_out = interp(model_D(F.softmax(pred)))
        loss_adv_pred = bce_loss(D_out, make_D_label(gt_label, ignore_mask))
        loss = loss_seg + args.lambda_adv_pred * loss_adv_pred
        '''
        loss = loss_seg
        # print('Loss is: ', loss)

        # proper normalization
        loss = loss / args.iter_size
        loss.backward()
        # print('Loss of numpy is: ', loss_seg.data.cpu().numpy())
        # print('Loss of numpy of zero is: ', loss_seg.data.cpu().numpy())
        # NOTE(review): ``loss_seg_value`` is never initialised in this
        # function.
        loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
        loss_list.append(loss_seg_value)
        # loss_adv_pred_value += loss_adv_pred.data.cpu().numpy()[0]/args.iter_size

        # train D
        # LD delete
        '''
        # bring back requires_grad
        for param in model_D.parameters():
            param.requires_grad = True

        # train with pred
        pred = pred.detach()

        if args.D_remain:
            pred = torch.cat((pred, pred_remain), 0)
            ignore_mask = np.concatenate((ignore_mask,ignore_mask_remain), axis = 0)

        D_out = interp(model_D(F.softmax(pred)))
        loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
        loss_D = loss_D/args.iter_size/2
        loss_D.backward()
        loss_D_value += loss_D.data.cpu().numpy()[0]

        # train with gt
        # get gt labels
        try:
            _, batch = trainloader_gt_iter.next()
        except:
            trainloader_gt_iter = enumerate(trainloader_gt)
            _, batch = trainloader_gt_iter.next()

        _, labels_gt, _, _ = batch
        D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
        ignore_mask_gt = (labels_gt.numpy() == 255)

        D_out = interp(model_D(D_gt_v))
        loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
        loss_D = loss_D/args.iter_size/2
        loss_D.backward()
        loss_D_value += loss_D.data.cpu().numpy()[0]
        '''

        optimizer.step()
        # optimizer_D.step()

        # Dice = 2*|P & G| / (|P| + |G|); the 1e-7 terms avoid division by
        # zero when both masks are empty.
        dice = (2 * num_intersection + 1e-7) / (num_prediction +
                                                num_ground_truth + 1e-7)
        print('exp = {}'.format(args.snapshot_dir))
        # NOTE(review): ``i_iter`` is not defined in this function.
        print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
            i_iter, args.num_steps, loss_seg_value))
        print(
            'dice: %.4f, num_prediction: %d, num_ground_truth: %d, num_intersection: %d'
            % (dice, num_prediction, num_ground_truth, num_intersection))

        # Disabled TensorBoard logging and validation pass:
        # if i_iter % UPDATE_TENSORBOARD_INTERVAL and i_iter != 0:
        #     # update tensorboard
        #     feed_dict = {
        #         dice_placeholder: dice,
        #         loss_placeholder: np.mean(loss_list)
        #     }
        #     summery_value = sess.run(summary_op, feed_dict)
        #     training_summary.add_summary(summery_value, i_iter)
        #     training_summary.flush()
        #
        #     # for validation
        #     val_num_prediction = 0
        #     val_num_ground_truth = 0
        #     val_num_intersection = 0
        #     loss_list = []
        #
        #     for _ in range(VAL_EXECUTE_TIMES):
        #         batch_image, batch_label = val_dataset.get_next_batch()
        #         batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
        #         batch_image = np.concatenate([batch_image, batch_image, batch_image], axis=1)
        #         # print('Shape: ', np.shape(batch_image))
        #         batch_image_torch = torch.Tensor(batch_image)
        #         images = Variable(batch_image_torch).cuda(args.gpu)
        #
        #         # LD delete
        #         # ignore_mask = (labels.numpy() == 255)
        #         pred = interp(model(images))
        #         pred_ny = pred.data.cpu().numpy()
        #         pred_ny = np.transpose(pred_ny, axes=(0, 2, 3, 1))
        #         pred_label_ny = np.squeeze(np.argmax(pred_ny, axis=3))
        #         val_num_prediction += np.sum(np.asarray(pred_label_ny, np.uint8))
        #         val_num_ground_truth += np.sum(np.asarray(batch_label >= 1, np.uint8))
        #         val_num_intersection += np.sum(np.asarray(np.logical_and(batch_label >= 1, pred_label_ny >= 1), np.uint8))
        #
        #         loss_seg = loss_calc(pred, batch_label, args.gpu)
        #         loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
        #         loss_list.append(loss_seg)
        #     dice = (2 * val_num_intersection + 1e-7) / (val_num_prediction + val_num_ground_truth + 1e-7)
        #     feed_dict = {
        #         dice_placeholder: dice,
        #         loss_placeholder: np.mean(loss_list)
        #     }
        #     summery_value = sess.run(summary_op, feed_dict)
        #     val_summary.add_summary(summery_value, i_iter)
        #     val_summary.flush()
        #     loss_list = []

    # NOTE(review): the writers closed here are only created in the
    # commented-out summary setup above.
    training_summary.close()
    val_summary.close()

    end = timeit.default_timer()
    print(end - start, 'seconds')
def train(log_file, arch, dataset, batch_size, iter_size, num_workers,
          partial_data, partial_data_size, partial_id, ignore_label, crop_size,
          eval_crop_size, is_training, learning_rate, learning_rate_d,
          supervised, lambda_adv_pred, lambda_semi, lambda_semi_adv, mask_t,
          semi_start, semi_start_adv, d_remain, momentum, not_restore_last,
          num_steps, power, random_mirror, random_scale, random_seed,
          restore_from, restore_from_d, eval_every, save_snapshot_every,
          snapshot_dir, weight_decay, device):
    """Train a segmentation network, optionally with adversarial/semi-supervised losses.

    Args:
        log_file: Path of the log file; ``''`` or ``'none'`` skips the
            "already exists" check, and ``'none'`` passes ``None`` to
            ``logger.LogFile``.  Training is refused if the file exists.
        arch: ``'deeplab2'``, ``'unet_resnet50'`` or ``'resnet101_deeplabv3'``.
        dataset: ``'pascal_aug'`` or ``'pascal'``.
        batch_size: Batch size of the three training loaders.
        iter_size: Number of sub-steps whose gradients are accumulated
            before each optimizer step.
        num_workers, is_training, not_restore_last: Accepted and logged but
            not otherwise used by this function (loaders use fixed worker
            counts).
        partial_data: Fraction of the training set used as labelled data.
        partial_data_size: Absolute labelled-set size; ``-1`` means "use
            ``partial_data``".
        partial_id: Optional pickle file holding the ordering of training
            ids; the first ``partial_size`` entries are the labelled set.
        ignore_label: Label value excluded from losses and evaluation.
        crop_size, eval_crop_size: ``"h,w"`` strings.
        learning_rate, learning_rate_d, power, momentum, weight_decay:
            Optimizer and polynomial LR-decay settings for the segmentation
            network (SGD) and discriminator (Adam).
        supervised: If true, only the cross-entropy loss is used and the
            discriminator is never trained.
        lambda_adv_pred, lambda_semi, lambda_semi_adv: Loss weights.
        mask_t: Discriminator-confidence threshold for pseudo-labels.
        semi_start, semi_start_adv: Iterations at which the semi-supervised
            cross-entropy / adversarial terms are switched on.
        d_remain: Also train the discriminator on unlabelled predictions.
        num_steps: Training runs for ``num_steps + 1`` iterations.
        random_mirror, random_scale: Training augmentation switches.
        random_seed: Seed for the labelled/unlabelled split.
        restore_from: Path or ``http`` URL of the initial weights.
        restore_from_d: Optional discriminator checkpoint.
        eval_every: Evaluate on the validation set every this many iterations.
        save_snapshot_every: Snapshot interval (needs ``snapshot_dir``).
        snapshot_dir: Output directory, or ``None`` to disable saving.
        device: Device string passed to ``torch.device``.

    Returns:
        None.  Progress is printed; checkpoints are written to
        ``snapshot_dir`` when it is not None.
    """
    # Must run before any other local is bound so that only the arguments
    # are captured.
    settings = locals().copy()

    import cv2
    import torch
    import torch.nn as nn
    from torch.utils import data, model_zoo
    import numpy as np
    import pickle
    import torch.optim as optim
    import torch.nn.functional as F
    import scipy.misc
    import sys
    import os
    import os.path as osp

    from model.deeplab import Res_Deeplab
    from model.unet import unet_resnet50
    from model.deeplabv3 import resnet101_deeplabv3
    from model.discriminator import FCDiscriminator
    from utils.loss import CrossEntropy2d, BCEWithLogitsLoss2d
    from utils.evaluation import EvaluatorIoU
    from dataset.voc_dataset import VOCDataSet

    import logger
    torch_device = torch.device(device)

    import time

    if log_file != '' and log_file != 'none':
        if os.path.exists(log_file):
            print('Log file {} already exists; exiting...'.format(log_file))
            return

    with logger.LogFile(log_file if log_file != 'none' else None):
        if dataset == 'pascal_aug':
            ds = VOCDataSet(augmented_pascal=True)
        elif dataset == 'pascal':
            ds = VOCDataSet(augmented_pascal=False)
        else:
            print('Dataset {} not yet supported'.format(dataset))
            return

        print('Command: {}'.format(sys.argv[0]))
        print('Arguments: {}'.format(' '.join(sys.argv[1:])))
        print('Settings: {}'.format(', '.join([
            '{}={}'.format(k, settings[k])
            for k in sorted(list(settings.keys()))
        ])))

        print('Loaded data')

        def loss_calc(pred, label):
            """Return the 2-D cross-entropy segmentation loss."""
            label = label.long().to(torch_device)
            criterion = CrossEntropy2d()
            return criterion(pred, label)

        def lr_poly(base_lr, iter, max_iter, power):
            """Polynomial learning-rate decay."""
            return base_lr * ((1 - float(iter) / max_iter)**(power))

        def adjust_learning_rate(optimizer, i_iter):
            """Set the segmentation LR; a second param group gets 10x."""
            lr = lr_poly(learning_rate, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def adjust_learning_rate_D(optimizer, i_iter):
            """Set the discriminator LR; a second param group gets 10x."""
            lr = lr_poly(learning_rate_d, i_iter, num_steps, power)
            optimizer.param_groups[0]['lr'] = lr
            if len(optimizer.param_groups) > 1:
                optimizer.param_groups[1]['lr'] = lr * 10

        def one_hot(label):
            """One-hot encode a label batch; ignore labels stay all-zero."""
            label = label.numpy()
            encoded = np.zeros((label.shape[0], ds.num_classes,
                                label.shape[1], label.shape[2]),
                               dtype=label.dtype)
            for i in range(ds.num_classes):
                encoded[:, i, ...] = (label == i)
            # ignore labels match no class index, so they stay zero
            return torch.tensor(encoded,
                                dtype=torch.float,
                                device=torch_device)

        def make_D_label(label, ignore_mask):
            """Build a discriminator target filled with ``label``.

            Positions where ``ignore_mask`` is true are set to
            ``ignore_label``.
            """
            ignore_mask = np.expand_dims(ignore_mask, axis=1)
            D_label = np.ones(ignore_mask.shape) * label
            D_label[ignore_mask] = ignore_label
            D_label = torch.tensor(D_label,
                                   dtype=torch.float,
                                   device=torch_device)
            return D_label

        h, w = map(int, eval_crop_size.split(','))
        eval_crop_size = (h, w)

        h, w = map(int, crop_size.split(','))
        crop_size = (h, w)

        # create network
        if arch == 'deeplab2':
            model = Res_Deeplab(num_classes=ds.num_classes)
        elif arch == 'unet_resnet50':
            model = unet_resnet50(num_classes=ds.num_classes)
        elif arch == 'resnet101_deeplabv3':
            model = resnet101_deeplabv3(num_classes=ds.num_classes)
        else:
            print('Architecture {} not supported'.format(arch))
            return

        # load pretrained parameters
        if restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(restore_from)
        else:
            saved_state_dict = torch.load(restore_from)

        # only copy the params that exist in current model (caffe-like)
        new_params = model.state_dict().copy()
        for name, param in new_params.items():
            if (name in saved_state_dict
                    and param.size() == saved_state_dict[name].size()):
                new_params[name].copy_(saved_state_dict[name])
        model.load_state_dict(new_params)

        model.train()
        model = model.to(torch_device)

        # init D
        model_D = FCDiscriminator(num_classes=ds.num_classes)
        if restore_from_d is not None:
            model_D.load_state_dict(torch.load(restore_from_d))
        model_D.train()
        model_D = model_D.to(torch_device)

        print('Built model')

        if snapshot_dir is not None:
            if not os.path.exists(snapshot_dir):
                os.makedirs(snapshot_dir)

        ds_train_xy = ds.train_xy(crop_size=crop_size,
                                  scale=random_scale,
                                  mirror=random_mirror,
                                  range01=model.RANGE01,
                                  mean=model.MEAN,
                                  std=model.STD)
        ds_train_y = ds.train_y(crop_size=crop_size,
                                scale=random_scale,
                                mirror=random_mirror,
                                range01=model.RANGE01,
                                mean=model.MEAN,
                                std=model.STD)
        ds_val_xy = ds.val_xy(crop_size=eval_crop_size,
                              scale=False,
                              mirror=False,
                              range01=model.RANGE01,
                              mean=model.MEAN,
                              std=model.STD)

        train_dataset_size = len(ds_train_xy)

        if partial_data_size != -1:
            # The old check compared the value with itself and never fired.
            if partial_data_size > train_dataset_size:
                print('partial-data-size > |train|: exiting')
                return

        if partial_data == 1.0 and (partial_data_size == -1 or
                                    partial_data_size == train_dataset_size):
            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=5,
                                          pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=5,
                                             pin_memory=True)
            trainloader_remain = None
            trainloader_remain_iter = None

            print('|train|={}'.format(train_dataset_size))
            print('|val|={}'.format(len(ds_val_xy)))
        else:
            # sample partial data
            if partial_data_size != -1:
                partial_size = partial_data_size
            else:
                partial_size = int(partial_data * train_dataset_size)

            if partial_id is not None:
                # NOTE: pickle executes arbitrary code on load; only use id
                # files you created yourself.  Binary mode is required by
                # pickle on Python 3.
                with open(partial_id, 'rb') as id_file:
                    train_ids = pickle.load(id_file)
                print('loading train ids from {}'.format(partial_id))
            else:
                rng = np.random.RandomState(random_seed)
                train_ids = list(rng.permutation(train_dataset_size))

            if snapshot_dir is not None:
                with open(osp.join(snapshot_dir, 'train_id.pkl'),
                          'wb') as id_file:
                    pickle.dump(train_ids, id_file)

            print('|train supervised|={}'.format(partial_size))
            print('|train unsupervised|={}'.format(train_dataset_size -
                                                   partial_size))
            print('|val|={}'.format(len(ds_val_xy)))

            print('supervised={}'.format(list(train_ids[:partial_size])))

            train_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])
            train_remain_sampler = data.sampler.SubsetRandomSampler(
                train_ids[partial_size:])
            train_gt_sampler = data.sampler.SubsetRandomSampler(
                train_ids[:partial_size])

            trainloader = data.DataLoader(ds_train_xy,
                                          batch_size=batch_size,
                                          sampler=train_sampler,
                                          num_workers=3,
                                          pin_memory=True)
            trainloader_remain = data.DataLoader(ds_train_xy,
                                                 batch_size=batch_size,
                                                 sampler=train_remain_sampler,
                                                 num_workers=3,
                                                 pin_memory=True)
            trainloader_gt = data.DataLoader(ds_train_y,
                                             batch_size=batch_size,
                                             sampler=train_gt_sampler,
                                             num_workers=3,
                                             pin_memory=True)

            trainloader_remain_iter = enumerate(trainloader_remain)

        testloader = data.DataLoader(ds_val_xy,
                                     batch_size=1,
                                     shuffle=False,
                                     pin_memory=True)

        print('Data loaders ready')

        trainloader_iter = enumerate(trainloader)
        trainloader_gt_iter = enumerate(trainloader_gt)

        # implement model.optim_parameters(args) to handle different models' lr setting

        # optimizer for segmentation network
        optimizer = optim.SGD(model.optim_parameters(learning_rate),
                              lr=learning_rate,
                              momentum=momentum,
                              weight_decay=weight_decay)
        optimizer.zero_grad()

        # optimizer for discriminator network
        optimizer_D = optim.Adam(model_D.parameters(),
                                 lr=learning_rate_d,
                                 betas=(0.9, 0.99))
        optimizer_D.zero_grad()

        # loss
        bce_loss = BCEWithLogitsLoss2d()

        print('Built optimizer')

        # labels for adversarial training
        pred_label = 0
        gt_label = 1

        # Running sums, reported and reset at every evaluation.
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_mask_accum = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        t1 = time.time()

        print('Training for {} steps...'.format(num_steps))
        for i_iter in range(num_steps + 1):
            model.train()
            model.freeze_batchnorm()

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            for sub_i in range(iter_size):
                # ---- train G ----
                if not supervised:
                    # don't accumulate grads in D
                    for param in model_D.parameters():
                        param.requires_grad = False

                # do semi first
                if (not supervised
                        and (lambda_semi > 0 or lambda_semi_adv > 0)
                        and i_iter >= semi_start_adv
                        and trainloader_remain is not None):
                    try:
                        _, batch = next(trainloader_remain_iter)
                    except StopIteration:
                        trainloader_remain_iter = enumerate(
                            trainloader_remain)
                        _, batch = next(trainloader_remain_iter)

                    # only access to img
                    images, _, _, _ = batch
                    images = images.float().to(torch_device)

                    pred = model(images)
                    pred_remain = pred.detach()

                    D_out = model_D(F.softmax(pred, dim=1))
                    D_out_sigmoid = torch.sigmoid(
                        D_out).data.cpu().numpy().squeeze(axis=1)

                    ignore_mask_remain = np.zeros(
                        D_out_sigmoid.shape).astype(bool)

                    # Keep the unweighted value for reporting; dividing the
                    # weighted loss by ``lambda_semi_adv`` afterwards raised
                    # ZeroDivisionError when only ``lambda_semi`` was > 0.
                    semi_adv_unweighted = bce_loss(
                        D_out, make_D_label(gt_label,
                                            ignore_mask_remain)) / iter_size
                    loss_semi_adv = lambda_semi_adv * semi_adv_unweighted
                    loss_semi_adv_value += float(semi_adv_unweighted)

                    if lambda_semi <= 0 or i_iter < semi_start:
                        loss_semi_adv.backward()
                        loss_semi_value = 0
                    else:
                        # produce ignore mask: pixels the discriminator is
                        # not confident about are excluded from the
                        # pseudo-labels
                        semi_ignore_mask = (D_out_sigmoid < mask_t)

                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = ignore_label

                        semi_ratio = 1.0 - float(
                            semi_ignore_mask.sum()) / semi_ignore_mask.size
                        loss_semi_mask_accum += float(semi_ratio)

                        if semi_ratio == 0.0:
                            # every pixel masked: nothing is back-propagated
                            # for this batch (unchanged behaviour)
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)

                            loss_semi = lambda_semi * loss_calc(pred, semi_gt)
                            loss_semi = loss_semi / iter_size
                            loss_semi_value += float(loss_semi) / lambda_semi
                            loss_semi += loss_semi_adv
                            loss_semi.backward()
                else:
                    loss_semi = None
                    loss_semi_adv = None
                    # Nothing unlabelled was processed in this sub-step.
                    pred_remain = None
                    ignore_mask_remain = None

                # train with source
                try:
                    _, batch = next(trainloader_iter)
                except StopIteration:
                    trainloader_iter = enumerate(trainloader)
                    _, batch = next(trainloader_iter)

                images, labels, _, _ = batch
                images = images.float().to(torch_device)
                ignore_mask = (labels.numpy() == ignore_label)
                pred = model(images)

                loss_seg = loss_calc(pred, labels)

                if supervised:
                    loss = loss_seg
                else:
                    D_out = model_D(F.softmax(pred, dim=1))

                    loss_adv_pred = bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask))

                    loss = loss_seg + lambda_adv_pred * loss_adv_pred
                    loss_adv_pred_value += float(loss_adv_pred) / iter_size

                # proper normalization
                loss = loss / iter_size
                loss.backward()
                loss_seg_value += float(loss_seg) / iter_size

                if not supervised:
                    # ---- train D ----
                    # bring back requires_grad
                    for param in model_D.parameters():
                        param.requires_grad = True

                    # train with pred
                    pred = pred.detach()

                    # Only append unlabelled predictions when the semi
                    # branch produced some; previously this raised NameError
                    # before ``semi_start_adv`` or in full-data mode.
                    if d_remain and pred_remain is not None:
                        pred = torch.cat((pred, pred_remain), 0)
                        ignore_mask = np.concatenate(
                            (ignore_mask, ignore_mask_remain), axis=0)

                    D_out = model_D(F.softmax(pred, dim=1))
                    loss_D = bce_loss(D_out,
                                      make_D_label(pred_label, ignore_mask))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

                    # train with gt
                    # get gt labels
                    try:
                        _, batch = next(trainloader_gt_iter)
                    except StopIteration:
                        trainloader_gt_iter = enumerate(trainloader_gt)
                        _, batch = next(trainloader_gt_iter)

                    _, labels_gt, _, _ = batch
                    D_gt_v = one_hot(labels_gt)
                    ignore_mask_gt = (labels_gt.numpy() == ignore_label)

                    D_out = model_D(D_gt_v)
                    loss_D = bce_loss(D_out,
                                      make_D_label(gt_label, ignore_mask_gt))
                    loss_D = loss_D / iter_size / 2
                    loss_D.backward()
                    loss_D_value += float(loss_D)

            optimizer.step()
            optimizer_D.step()

            sys.stdout.write('.')
            sys.stdout.flush()

            if i_iter % eval_every == 0 and i_iter != 0:
                model.eval()
                with torch.no_grad():
                    evaluator = EvaluatorIoU(ds.num_classes)
                    for index, batch in enumerate(testloader):
                        image, label, size, name = batch
                        size = size[0].numpy()
                        image = image.float().to(torch_device)
                        output = model(image)
                        output = output.cpu().data[0].numpy()

                        # crop prediction and label to the size reported by
                        # the loader
                        output = output[:, :size[0], :size[1]]
                        gt = np.asarray(
                            label[0].numpy()[:size[0], :size[1]], dtype=int)

                        output = output.transpose(1, 2, 0)
                        output = np.asarray(np.argmax(output, axis=2),
                                            dtype=int)

                        evaluator.sample(gt,
                                         output,
                                         ignore_value=ignore_label)

                        sys.stdout.write('+')
                        sys.stdout.flush()

                per_class_iou = evaluator.score()
                mean_iou = per_class_iou.mean()

                # turn the running sums into per-iteration averages
                loss_seg_value /= eval_every
                loss_adv_pred_value /= eval_every
                loss_D_value /= eval_every
                loss_semi_mask_accum /= eval_every
                loss_semi_value /= eval_every
                loss_semi_adv_value /= eval_every

                sys.stdout.write('\n')
                t2 = time.time()

                print(
                    'iter = {:8d}/{:8d}, took {:.3f}s, loss_seg = {:.6f}, loss_adv_p = {:.6f}, loss_D = {:.6f}, loss_semi_mask_rate = {:.3%} loss_semi = {:.6f}, loss_semi_adv = {:.3f}'
                    .format(i_iter, num_steps, t2 - t1, loss_seg_value,
                            loss_adv_pred_value, loss_D_value,
                            loss_semi_mask_accum, loss_semi_value,
                            loss_semi_adv_value))

                for i, (class_name,
                        iou) in enumerate(zip(ds.class_names,
                                              per_class_iou)):
                    print('class {:2d} {:12} IU {:.2f}'.format(
                        i, class_name, iou))

                print('meanIOU: ' + str(mean_iou) + '\n')

                loss_seg_value = 0
                loss_adv_pred_value = 0
                loss_D_value = 0
                loss_semi_value = 0
                loss_semi_mask_accum = 0
                loss_semi_adv_value = 0

                t1 = t2

            if (snapshot_dir is not None
                    and i_iter % save_snapshot_every == 0 and i_iter != 0):
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

        if snapshot_dir is not None:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(snapshot_dir, 'VOC_' + str(num_steps) + '_D.pth'))
def main():
    """Train the segmentation network on liver CT slices with TensorBoard logging.

    Builds a training and a validation ``LiverDataset``, optionally copies
    matching pretrained weights into ``Res_Deeplab``, then runs SGD from
    ``iter_start`` up to ``args.num_steps``. Each step accumulates
    ``args.iter_size`` batches, prints the segmentation loss and a
    foreground Dice score, and writes both to the training summary.
    A snapshot is taken every ``args.save_pred_every`` steps and the final
    weights are saved on the last step.

    Reads names that are not defined in this function (``args``,
    ``RESTORE_FROM``, ``SUMMARY_DIR``, ``iter_start``, ``start``,
    ``loss_calc``, ``adjust_learning_rate``, ``save_model``); presumably
    module-level globals -- verify against the rest of the file.
    """
    # Local import: only this entry point needs the liver dataset.
    from dataset.LiverDataset.liver_dataset import LiverDataset

    # ---- dataset configuration ------------------------------------------
    user_name = 'give'
    validation_interval = 800
    batch_size = 1
    n_neighboringslices = 5
    input_size = 400
    slice_type = 'axial'
    oversample = False
    label_of_interest = 1
    label_required = 0
    max_slice_tries_val = 0
    max_slice_tries_train = 2
    fuse_labels = True
    apply_crop = False
    train_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_2"
    test_data_dir = "/home/" + user_name + "/Documents/dataset/ISBI2017/media/nas/01_Datasets/CT/LITS/Training_Batch_1"

    train_dataset = LiverDataset(data_dir=train_data_dir,
                                 slice_type=slice_type,
                                 n_neighboringslices=n_neighboringslices,
                                 input_size=input_size,
                                 oversample=oversample,
                                 label_of_interest=label_of_interest,
                                 label_required=label_required,
                                 max_slice_tries=max_slice_tries_train,
                                 fuse_labels=fuse_labels,
                                 apply_crop=apply_crop,
                                 interval=validation_interval,
                                 is_training=True,
                                 batch_size=batch_size,
                                 data_augmentation=True)
    # NOTE: the validation pass is currently disabled, so ``val_dataset`` and
    # ``val_summary`` are created (and the writer closed) but never fed.
    val_dataset = LiverDataset(data_dir=test_data_dir,
                               slice_type=slice_type,
                               n_neighboringslices=n_neighboringslices,
                               input_size=input_size,
                               oversample=oversample,
                               label_of_interest=label_of_interest,
                               label_required=label_required,
                               max_slice_tries=max_slice_tries_val,
                               fuse_labels=fuse_labels,
                               apply_crop=apply_crop,
                               interval=validation_interval,
                               is_training=False,
                               batch_size=batch_size)

    # ---- TensorBoard summaries (TF is used only as a logging backend) ----
    training_summary = tf.summary.FileWriter(os.path.join(SUMMARY_DIR, 'train'))
    val_summary = tf.summary.FileWriter(os.path.join(SUMMARY_DIR, 'val'))
    dice_placeholder = tf.placeholder(tf.float32, [], name='dice')
    loss_placeholder = tf.placeholder(tf.float32, [], name='loss')
    tf.summary.scalar('dice', dice_placeholder)
    tf.summary.scalar('loss', loss_placeholder)
    summary_op = tf.summary.merge_all()
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    prefix_name = 'Liver'
    cudnn.enabled = True

    try:
        # ---- create network ---------------------------------------------
        model = Res_Deeplab(num_classes=args.num_classes,
                            slice_num=n_neighboringslices,
                            gpu_id=0)
        if RESTORE_FROM is not None:
            # load pretrained parameters
            if args.restore_from[:4] == 'http':
                saved_state_dict = model_zoo.load_url(args.restore_from)
            else:
                saved_state_dict = torch.load(args.restore_from)

            # only copy the params that exist in current model (caffe-like)
            new_params = model.state_dict().copy()
            for name, param in new_params.items():
                print(name)
                if (name in saved_state_dict
                        and param.size() == saved_state_dict[name].size()):
                    new_params[name].copy_(saved_state_dict[name])
                    print('copy {}'.format(name))
            model.load_state_dict(new_params)

        model.train()
        model.cuda(args.gpu)
        cudnn.benchmark = True

        if not os.path.exists(args.snapshot_dir):
            os.makedirs(args.snapshot_dir)

        # implement model.optim_parameters(args) to handle different models'
        # lr setting
        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        optimizer.zero_grad()

        # bilinear upsampling back to the input resolution; ``align_corners``
        # is only passed on torch versions that accept it
        if version.parse(torch.__version__) >= version.parse('0.4.0'):
            interp = nn.Upsample(size=(input_size, input_size),
                                 mode='bilinear',
                                 align_corners=True)
        else:
            interp = nn.Upsample(size=(input_size, input_size),
                                 mode='bilinear')

        loss_list = []
        for i_iter in range(iter_start, args.num_steps):
            loss_seg_value = 0
            num_prediction = 0
            num_ground_truth = 0
            num_intersection = 0

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)

            for sub_i in range(args.iter_size):
                batch_image, batch_label = train_dataset.get_next_batch()
                # move the last axis to position 1 before handing to torch
                batch_image = np.transpose(batch_image, axes=(0, 3, 1, 2))
                images = torch.Tensor(batch_image).cuda(args.gpu)

                pred = interp(model(images))

                # move axis 1 to the end and take the argmax over it to get
                # the predicted label map
                pred_np = np.transpose(pred.data.cpu().numpy(),
                                       axes=(0, 2, 3, 1))
                pred_label_np = np.squeeze(np.argmax(pred_np, axis=3))

                # Dice statistics: each reduction is computed once per batch
                foreground = batch_label >= 1
                num_prediction += np.sum(np.asarray(pred_label_np, np.uint8))
                num_ground_truth += np.sum(np.asarray(foreground, np.uint8))
                num_intersection += np.sum(
                    np.asarray(np.logical_and(foreground, pred_label_np >= 1),
                               np.uint8))

                loss_seg = loss_calc(pred, batch_label, args.gpu)

                # proper normalization over the accumulated sub-batches
                loss = loss_seg / args.iter_size
                loss.backward()

                loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
                # NOTE(review): this stores the running total, not the
                # per-batch loss; the two only coincide for iter_size == 1.
                loss_list.append(loss_seg_value)

            optimizer.step()

            # the epsilon keeps the ratio defined when both masks are empty
            dice = (2 * num_intersection + 1e-7) / (
                num_prediction + num_ground_truth + 1e-7)
            print('exp = {}'.format(args.snapshot_dir))
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}'.format(
                i_iter, args.num_steps, loss_seg_value))
            print(
                'dice: %.4f, num_prediction: %d, num_ground_truth: %d, num_intersection: %d'
                % (dice, num_prediction, num_ground_truth, num_intersection))

            if i_iter >= args.num_steps - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             prefix_name + str(args.num_steps) + '.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                save_model(model, args.snapshot_dir, prefix_name, i_iter, 2)

            # update tensorboard (runs on every step except the last one,
            # which leaves the loop via ``break`` above)
            feed_dict = {
                dice_placeholder: dice,
                loss_placeholder: np.mean(loss_list)
            }
            summary_value = sess.run(summary_op, feed_dict)
            training_summary.add_summary(summary_value, i_iter)
            training_summary.flush()
            loss_list = []
    finally:
        # release the event files and the TF session even if training fails
        training_summary.close()
        val_summary.close()
        sess.close()

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Adversarial (semi-)supervised training of ``DeeplabMulti`` with a
    fully-convolutional discriminator and an extra Laplacian term.

    Steps, as visible in the code:
      1. build the segmentation network and copy every pretrained tensor
         whose name and size match;
      2. build the discriminator, datasets, loaders and both optimizers;
      3. for each of ``args.num_steps`` steps accumulate ``args.iter_size``
         sub-batches of (optional) semi-supervised loss, segmentation +
         adversarial loss minus the Laplacian term, and the two
         discriminator losses, then step both optimizers;
      4. save ``CITY_*.pth`` snapshots every ``args.save_pred_every`` steps
         and at the final step.

    Reads names not defined here (``args``, ``start``, ``loss_calc``,
    ``make_D_label``, ``one_hot``, ...); presumably module-level -- verify
    against the rest of the file.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = DeeplabMulti(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from,
                                      map_location='cuda:0')

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    total_samples = args.num_steps * args.iter_size * args.batch_size
    train_dataset = cityscapesDataSet(max_iters=total_samples,
                                      scale=args.random_scale)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = cityscapesDataSet(max_iters=total_samples,
                                         scale=args.random_scale)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # pickle needs a binary handle; ``with`` also closes it.
            # NOTE(review): unpickling is only safe for trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             sampler=train_remain_sampler,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         sampler=train_gt_sampler,
                                         batch_size=args.batch_size,
                                         shuffle=False,
                                         num_workers=args.num_workers,
                                         pin_memory=True)

    # NOTE(review): the semi-supervised step draws from ``trainloader``;
    # ``trainloader_remain`` (the unlabeled split) is built above but never
    # iterated. Confirm this is intended.
    trainloader_remain_iter = enumerate(trainloader)
    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models'
    # lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling; ``align_corners`` is only passed on torch
    # versions that accept it
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0
        loss_laplacian = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):
            # ---- train G ----

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if ((args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                    and i_iter >= args.semi_start_adv):
                # restart the loader only when it is exhausted; any other
                # error from the loader now propagates
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                # NOTE(review): after an OOM this only frees the cache and
                # carries on, so ``pred`` below is whatever it was before
                # (or unbound on the very first batch).
                try:
                    pred = interp(model(images))
                except RuntimeError as exception:
                    if "out of memory" in str(exception):
                        print("WARNING: out of memory")
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        raise exception
                pred_remain = pred.detach()

                D_out = interp(model_D(F.softmax(pred)))
                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # all-False mask: no pixel is ignored for unlabeled images
                # (builtin ``bool``: the ``np.bool`` alias no longer exists)
                ignore_mask_remain = np.zeros(
                    D_out_sigmoid.shape).astype(bool)

                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask_remain))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # logged value is the loss without its lambda multiplier
                loss_semi_adv_value += (loss_semi_adv.data.cpu().numpy() /
                                        args.lambda_semi_adv)

                if args.lambda_semi <= 0 or i_iter < args.semi_start:
                    loss_semi_adv.backward()
                    loss_semi_value = 0
                else:
                    # produce ignore mask: pixels the discriminator scores
                    # below ``args.mask_T`` are set to 255 in the pseudo
                    # labels
                    semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                    semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                    semi_gt[semi_ignore_mask] = 255

                    semi_ratio = 1.0 - float(
                        semi_ignore_mask.sum()) / semi_ignore_mask.size
                    print('semi ratio: {:.4f}'.format(semi_ratio))

                    if semi_ratio == 0.0:
                        loss_semi_value += 0
                    else:
                        semi_gt = torch.FloatTensor(semi_gt)

                        loss_semi = args.lambda_semi * loss_calc(
                            pred, semi_gt, args.gpu)
                        loss_semi = loss_semi / args.iter_size
                        loss_semi_value += (loss_semi.data.cpu().numpy() /
                                            args.lambda_semi)
                        loss_semi += loss_semi_adv
                        loss_semi.backward()
            else:
                loss_semi = None
                loss_semi_adv = None

            # train with source
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)

            # NOTE(review): same stale-``pred`` caveat as the OOM handler
            # in the semi-supervised branch.
            try:
                pred = interp(model(images))
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("WARNING: out of memory")
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                else:
                    raise exception

            # Laplacian term for the first sample of the batch: sum up to
            # 19 prediction channels, filter that sum and the label map
            # with cv2.Laplacian and compare them with the BCE loss.
            # NOTE(review): 1280x720 and 19 are hard-coded -- presumably the
            # crop size and class count; confirm against args. Both filtered
            # maps are rebuilt from numpy arrays, so no gradient reaches the
            # network through this term.
            for i in range(1):
                imagess = torch.zeros(1280, 720).cuda()
                for j in range(19):
                    try:
                        imagess += pred[i, j, :, :].reshape(1280, 720)
                    except IndexError:
                        pass
                try:
                    label = labels[i, :, :].reshape(1280, 720).cuda()
                except IndexError:
                    pass
                imagess = torch.from_numpy(
                    cv2.Laplacian(imagess.cpu().detach().numpy(), -1)).cuda()
                labell = torch.from_numpy(
                    cv2.Laplacian(label.cpu().detach().numpy(), -1)).cuda()
                imagess = imagess.reshape(1, 1, 1280, 720)
                labell = labell.reshape(1, 1, 1280, 720)
                loss_laplacian = bce_loss(imagess, labell)

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred)))
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = (loss_seg + args.lambda_adv_pred * loss_adv_pred -
                    loss_laplacian)

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += (loss_adv_pred.data.cpu().numpy() /
                                    args.iter_size)

            # ---- train D ----

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred
            pred = pred.detach()

            if args.D_remain:
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate(
                    (ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            # each of the two discriminator passes contributes half
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}, loss_laplacian = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value, loss_laplacian))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'CITY_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'CITY_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Adversarial semi-supervised training on VOC with a second
    discriminator that is split off during training.

    Up to step ``args.discr_split`` a single discriminator ``model_D`` is
    trained. At that step ``model_D_ul`` is created as a copy of
    ``model_D`` with its own Adam optimizer; from then on ``model_D_ul``
    scores predictions on the unlabeled loader (and, as "real" samples,
    no-grad predictions on ground-truth images) while ``model_D`` keeps
    scoring labeled predictions against one-hot ground truth.

    Snapshots ``VOC_<step>_<seed>[_D].pth`` are saved every
    ``args.save_pred_every`` steps and at the final step.

    Reads names not defined here (``args``, ``start``, ``loss_calc``,
    ``make_D_label``, ``one_hot``, ...); presumably module-level -- verify
    against the rest of the file.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True
    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)
    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # pickle needs a binary handle; ``with`` also closes it.
            # NOTE(review): unpickling is only safe for trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models'
    # lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling; ``align_corners`` is only passed on torch
    # versions that accept it
    bce_loss = BCEWithLogitsLoss2d()
    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_D_ul_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        # creating 2nd discriminator as a copy of the 1st one
        if i_iter == args.discr_split:
            model_D_ul = FCDiscriminator(num_classes=args.num_classes)
            # copy from ``model_D`` (the discriminator trained so far)
            model_D_ul.load_state_dict(model_D.state_dict())
            model_D_ul.train()
            model_D_ul.cuda(args.gpu)
            optimizer_D_ul = optim.Adam(model_D_ul.parameters(),
                                        lr=args.learning_rate_D,
                                        betas=(0.9, 0.99))

        # start training 2nd discriminator after specified number of steps
        if i_iter >= args.discr_split:
            optimizer_D_ul.zero_grad()
            adjust_learning_rate_D(optimizer_D_ul, i_iter)

        for sub_i in range(args.iter_size):
            # ---- train Segmentation ----

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # don't accumulate grads in D_ul, in case split has already
            # been made
            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = False

            # do semi-supervised training first
            if args.lambda_semi_adv > 0 and i_iter >= args.semi_start_adv:
                # restart the loader only when it is exhausted; any other
                # error from the loader now propagates
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)

                pred = interp(model(images))
                pred_remain = pred.detach()

                # choose discriminator depending on the iteration
                if i_iter >= args.discr_split:
                    D_out = interp(model_D_ul(F.softmax(pred)))
                else:
                    D_out = interp(model_D(F.softmax(pred)))

                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # all-False mask: no pixel is ignored for unlabeled images
                # (builtin ``bool``: the ``np.bool`` alias no longer exists)
                ignore_mask_remain = np.zeros(
                    D_out_sigmoid.shape).astype(bool)

                # adversarial loss
                loss_semi_adv = args.lambda_semi_adv * bce_loss(
                    D_out,
                    make_D_label(gt_label, ignore_mask_remain, args.gpu))
                loss_semi_adv = loss_semi_adv / args.iter_size

                # true loss value without multiplier
                loss_semi_adv_value += (loss_semi_adv.data.cpu().numpy() /
                                        args.lambda_semi_adv)
                loss_semi_adv.backward()
            else:
                loss_semi = None
                loss_semi_adv = None

            # train with labeled images
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))
            D_out = interp(model_D(F.softmax(pred)))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)
            loss_adv_pred = bce_loss(
                D_out, make_D_label(gt_label, ignore_mask, args.gpu))
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size
            loss_adv_pred_value += (loss_adv_pred.data.cpu().numpy() /
                                    args.iter_size)

            # ---- train D and D_ul ----

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True
            if i_iter >= args.discr_split:
                for param in model_D_ul.parameters():
                    param.requires_grad = True

            # train D with pred
            pred = pred.detach()

            # before split, train D with both labeled and unlabeled
            if (args.D_remain and i_iter < args.discr_split
                    and (args.lambda_semi > 0 or args.lambda_semi_adv > 0)):
                pred = torch.cat((pred, pred_remain), 0)
                ignore_mask = np.concatenate(
                    (ignore_mask, ignore_mask_remain), axis=0)

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(
                D_out, make_D_label(pred_label, ignore_mask, args.gpu))
            # each of the two discriminator passes contributes half
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pred on unlabeled
            if i_iter >= args.discr_split and (args.lambda_semi > 0
                                               or args.lambda_semi_adv > 0):
                D_ul_out = interp(model_D_ul(F.softmax(pred_remain)))
                loss_D_ul = bce_loss(
                    D_ul_out,
                    make_D_label(pred_label, ignore_mask_remain, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            images_gt, labels_gt, _, _ = batch
            images_gt = Variable(images_gt).cuda(args.gpu)
            # segmentation output on the gt images, computed without grad;
            # fed to D_ul below in place of ground truth
            with torch.no_grad():
                pred_l = interp(model(images_gt))

            # train D with gt
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(
                D_out, make_D_label(gt_label, ignore_mask_gt, args.gpu))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy()

            # train D_ul with pseudo_gt (gt are substituted for pred)
            if i_iter >= args.discr_split:
                D_ul_out = interp(model_D_ul(F.softmax(pred_l)))
                loss_D_ul = bce_loss(
                    D_ul_out,
                    make_D_label(gt_label, ignore_mask_gt, args.gpu))
                loss_D_ul = loss_D_ul / args.iter_size / 2
                loss_D_ul.backward()
                loss_D_ul_value += loss_D_ul.data.cpu().numpy()

        optimizer.step()
        optimizer_D.step()
        if i_iter >= args.discr_split:
            optimizer_D_ul.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_D_ul={5:.3f}, loss_semi = {6:.3f}, loss_semi_adv = {7:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_D_ul_value,
                    loss_semi_value, loss_semi_adv_value))

        # NOTE(review): ``model_D_ul`` is never written to disk.
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(args.num_steps) + '_' +
                    str(args.random_seed) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(i_iter) + '_' +
                    str(args.random_seed) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + str(i_iter) + '_' +
                    str(args.random_seed) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train ``Res_Deeplab`` in ``args.mode`` as an image-level classifier.

    Copies every pretrained tensor whose name and size match, builds a
    loader over ``VOCClsDataSet`` (the full set or a fraction
    ``args.partial_data`` of it), then runs SGD for ``args.num_steps``
    steps, accumulating ``args.iter_size`` batches of ``BCEWithLogitsLoss2d``
    per step. Weights are saved every ``args.save_pred_every`` steps and on
    the final step.

    Reads names not defined here (``args``, ``start``,
    ``adjust_learning_rate``, ...); presumably module-level -- verify
    against the rest of the file.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes, mode=args.mode)

    # load pretrained parameters
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    bce_loss = BCEWithLogitsLoss2d()

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCClsDataSet(args.data_dir,
                                  args.data_list,
                                  crop_size=input_size,
                                  scale=True,
                                  mirror=True,
                                  mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # pickle needs a binary handle; ``with`` also closes it.
            # NOTE(review): unpickling is only safe for trusted id files.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            # materialise the range: shuffle needs a mutable sequence
            train_ids = list(range(train_dataset_size))
            np.random.seed(args.seed)
            np.random.shuffle(train_ids)

        ids_path = osp.join(args.snapshot_dir,
                            'train_id_seed_' + str(args.seed) + '_.pkl')
        with open(ids_path, 'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)

    trainloader_iter = enumerate(trainloader)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # shared stem of every checkpoint file name
    checkpoint_stem = ('VOC_classifier_VOCcls_pd_' + str(args.partial_data) +
                       '_seed_' + str(args.seed) + '_')

    for i_iter in range(args.num_steps):
        loss_value = 0
        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):
            # train with source; restart the loader only when it is
            # exhausted, any other loader error now propagates
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, cls_label, _, _ = batch
            images = Variable(images).cuda(args.gpu)

            cls_pred = model(images)
            cls_label = Variable(torch.FloatTensor(cls_label)).cuda(args.gpu)

            # append two singleton axes to the prediction before the 2-D loss
            loss = bce_loss(torch.unsqueeze(torch.unsqueeze(cls_pred, 2), 3),
                            cls_label)

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_value += loss.data.cpu().numpy() / args.iter_size

        optimizer.step()

        print('iter = {0:8d}/{1:8d}, loss = {2:.3f}'.format(
            i_iter, args.num_steps, loss_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         checkpoint_stem + str(args.num_steps) + '.pth'))

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         checkpoint_stem + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    # formatted as one string so the output is the same whether or not
    # ``print`` is a function in the running interpreter
    print('{} seconds'.format(end - start))
def main():
    """Train a PSPNet jointly on segmentation and image-level classification.

    Configuration is read from the module-level ``args``; the elapsed time
    printed at the end is measured from the module-level ``start``.

    The network returns a segmentation output and a classification output;
    the optimized loss is ``NLL2d`` on the former plus ``BCEWithLogitsLoss``
    on the latter.  The discriminator of the adversarial variant is disabled
    in this script, so the adversarial / semi-supervised figures in the log
    line stay at 0.  Checkpoints go to ``args.snapshot_dir``.
    """
    models = {
        'resnet101':
        lambda: PSPNet(n_classes=21,
                       sizes=(1, 2, 3, 6),
                       psp_size=2048,
                       deep_features_size=1024,
                       backend='resnet101')
    }

    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = models['resnet101']()

    # load pretrained parameters
    # NOTE(review): torch.load unpickles the checkpoint; only restore from
    # trusted files.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model = nn.DataParallel(model)
    model.cuda()

    cudnn.benchmark = True

    # The discriminator (FCDiscriminator) is intentionally not created here.

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data == 0:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        # This branch used to compute ``partial_size`` only, leaving
        # ``trainloader`` undefined (NameError right below).  Both loaders
        # are now built over the same random subset of the training ids.
        import pickle

        partial_size = int(args.partial_data * train_dataset_size)
        train_ids = torch.randperm(train_dataset_size).tolist()

        # Persist the ordering so the labeled subset can be identified later.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        labeled_ids = train_ids[:partial_size]
        trainloader = data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            sampler=data.sampler.SubsetRandomSampler(labeled_ids),
            num_workers=3,
            pin_memory=True)
        trainloader_gt = data.DataLoader(
            train_gt_dataset,
            batch_size=args.batch_size,
            sampler=data.sampler.SubsetRandomSampler(labeled_ids),
            num_workers=3,
            pin_memory=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # optimizer for segmentation network (all parameters share one lr group)
    optimizer = optim.SGD(model.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # loss / bilinear upsampling
    # NOTE(review): bce_loss, interp, pred_label, gt_label and
    # trainloader_gt_iter are set up but not used by the loop below; they
    # look like leftovers from the adversarial variant of this script.
    bce_loss = BCEWithLogitsLoss2d()

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    seg_criterion = NLL2d().cuda()
    cls_criterion = nn.BCEWithLogitsLoss(weight=None)

    # Checkpoints are named after this script's file name without extension.
    script_name = os.path.abspath(__file__).split('/')[-1].split('.')[0]

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):

            # train G
            # train with source; restart the loader once it is exhausted.
            # ``next(it)`` replaces ``it.next()``, which enumerate objects
            # do not provide on Python 3.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _, y_cls = batch
            labels = Variable(labels.long()).cuda()
            y_cls = Variable(y_cls.float()).cuda()
            images = Variable(images).cuda()

            out, out_cls = model(images)
            seg_loss = seg_criterion(out, labels)
            cls_loss = cls_criterion(out_cls, y_cls)
            loss = seg_loss + cls_loss

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()

            # ravel() makes the scalar extraction work whether the loss is
            # a 1-element array (old PyTorch) or 0-dim (PyTorch >= 0.4,
            # where ``[0]`` raises IndexError).
            loss_seg_value += (seg_loss.data.cpu().numpy().ravel()[0] /
                               args.iter_size)

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir, 'VOC_' + script_name + '_' +
                    str(args.num_steps) + '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + script_name + '_' + str(i_iter) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Adversarial semi-supervised training of the segmentation network.

    Configuration is read from the module-level ``args``; the elapsed time
    printed at the end is measured from the module-level ``start``.

    Each optimisation step accumulates gradients over ``args.iter_size``
    sub-batches and consists of:

    1. (optional) a semi-supervised pass on the "remaining" images, using the
       network's own arg-max prediction as target and ignoring pixels whose
       discriminator score is below ``args.mask_T``;
    2. a supervised pass (segmentation loss + weighted adversarial loss)
       with the discriminator frozen;
    3. two discriminator passes: one on detached predictions, one on
       one-hot ground truth.

    Checkpoints of both networks go to ``args.snapshot_dir``.
    """
    # Parse "h,w" from the input_size argument into an integer pair.
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = False
    gpu = args.gpu

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    # NOTE(review): torch.load unpickles the checkpoint; only restore from
    # trusted files.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like):
    # state_dict() maps every parameter / persistent buffer name to its
    # tensor, and an entry is overwritten only when the checkpoint holds a
    # tensor with the same name and the same size.
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    # switch to training mode
    model.train()
    cudnn.benchmark = True
    model.cuda(gpu)

    # init D
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickles must be read in binary mode on Python 3; the context
            # manager also closes the handle.
            # NOTE(review): unpickling executes arbitrary code; only load
            # id files you produced yourself.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            # A real list is needed so that it can be shuffled in place.
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # Write the id ordering next to the snapshots.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): the size is passed as (input_size[1], input_size[0]),
    # i.e. (w, h); harmless for square crops, confirm for non-square ones.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):
        print("Iter:", i_iter)

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        loss_semi_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # train G

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # do semi first
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                # NOTE(review): the "remain" loader only exists when
                # args.partial_data is set; otherwise this raises NameError.
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(gpu)

                pred = interp(model(images))
                D_out = interp(model_D(F.softmax(pred)))

                D_out_sigmoid = F.sigmoid(D_out).data.cpu().numpy().squeeze(
                    axis=1)

                # produce ignore mask: pixels the discriminator scores below
                # the threshold are excluded (label 255) from the target
                semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt[semi_ignore_mask] = 255

                # fraction of pixels that are kept (not ignored)
                semi_ratio = 1.0 - float(
                    semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                if semi_ratio == 0.0:
                    loss_semi_value += 0
                else:
                    semi_gt = torch.FloatTensor(semi_gt)

                    loss_semi = args.lambda_semi * loss_calc(
                        pred, semi_gt, args.gpu)
                    loss_semi = loss_semi / args.iter_size
                    loss_semi.backward()
                    # ravel() handles both 1-element and 0-dim loss arrays.
                    loss_semi_value += (
                        loss_semi.data.cpu().numpy().ravel()[0] /
                        args.lambda_semi)
            else:
                loss_semi = None

            # train with source; restart the loader once it is exhausted
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(gpu)
            ignore_mask = (labels.numpy() == 255)
            pred = interp(model(images))

            loss_seg = loss_calc(pred, labels, args.gpu)

            D_out = interp(model_D(F.softmax(pred)))

            # The generator is trained to make D output the "gt" label.
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value += (loss_seg.data.cpu().numpy().ravel()[0] /
                               args.iter_size)
            loss_adv_pred_value += (
                loss_adv_pred.data.cpu().numpy().ravel()[0] /
                args.iter_size)

            # train D

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred (detached so no gradient reaches G)
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy().ravel()[0]

            # train with gt
            # get gt labels
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.data.cpu().numpy().ravel()[0]

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Train the segmentation network with a self-training baseline.

    Configuration is read from the module-level ``args``; the elapsed time
    printed at the end is measured from the module-level ``start``.

    Each optimisation step accumulates gradients over ``args.iter_size``
    sub-batches.  Every sub-batch does a supervised pass on labeled images
    and, once ``i_iter >= args.semi_start`` with ``args.lambda_semi > 0``,
    an additional pass on the "remaining" images that uses the network's own
    arg-max prediction as the target.  Checkpoints go to
    ``args.snapshot_dir``.
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    cudnn.enabled = True

    np.random.seed(args.random_seed)

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    # NOTE(review): torch.load unpickles the checkpoint; only restore from
    # trusted files.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in new_params.items():
        print(name)
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
            print('copy {}'.format(name))
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # load dataset
    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickles must be read in binary mode on Python 3; the context
            # manager also closes the handle.
            # NOTE(review): unpickling executes arbitrary code; only load
            # id files you produced yourself.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = np.arange(train_dataset_size)
            np.random.shuffle(train_ids)

        # Persist the ordering so the labeled subset can be identified later.
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        # labeled data
        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        # unlabeled data
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    # NOTE(review): the gt loader/iterator and bce_loss are created but not
    # used by the loop below.
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        # Stays 0 whenever the unlabeled pass does not run.  It used to be
        # set to None in that case, which made the '{3:.3f}' log format
        # below raise TypeError.
        loss_unlabeled_seg_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)

        for sub_i in range(args.iter_size):

            # train Segmentation

            # train with labeled images; restart the loader when exhausted.
            # ``next(it)`` replaces ``it.next()``, which enumerate objects
            # do not provide on Python 3.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cuda(args.gpu)
            pred = interp(model(images))

            # computing loss
            loss_seg = loss_calc(pred, labels, args.gpu)

            # proper normalization
            loss = loss_seg / args.iter_size
            loss.backward()
            loss_seg_value += loss_seg.data.cpu().numpy() / args.iter_size

            # train with unlabeled
            if args.lambda_semi > 0 and i_iter >= args.semi_start:
                # NOTE(review): the "remain" loader only exists when
                # args.partial_data is set; otherwise this raises NameError.
                try:
                    _, batch = next(trainloader_remain_iter)
                except StopIteration:
                    trainloader_remain_iter = enumerate(trainloader_remain)
                    _, batch = next(trainloader_remain_iter)

                # only access to img
                images, _, _, _ = batch
                images = Variable(images).cuda(args.gpu)
                pred = interp(model(images))

                # The network's own arg-max over axis 1 is the target.
                semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                semi_gt = torch.FloatTensor(semi_gt)

                loss_unlabeled_seg = args.lambda_semi * loss_calc(
                    pred, semi_gt, args.gpu)
                loss_unlabeled_seg = loss_unlabeled_seg / args.iter_size
                loss_unlabeled_seg.backward()
                # Logged without the lambda_semi weighting.
                loss_unlabeled_seg_value += (
                    loss_unlabeled_seg.data.cpu().numpy() /
                    args.lambda_semi)

        optimizer.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_unlabeled_seg = {3:.3f} '
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_unlabeled_seg_value))

        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + str(args.num_steps) + '_' +
                    str(args.lambda_semi) + '_' + str(args.random_seed) +
                    '.pth'))
            break

        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.snapshot_dir,
                    'VOC_' + str(i_iter) + '_' + str(args.lambda_semi) +
                    '_' + str(args.random_seed) + '.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Adversarial semi-supervised training with a per-iteration loss log.

    Configuration is read from the module-level ``args``; the elapsed time
    printed at the end is measured from the module-level ``start``.

    Each optimisation step accumulates gradients over ``args.iter_size``
    sub-batches and consists of an optional semi-supervised pass on the
    "remaining" images, a supervised pass with the discriminator frozen, and
    two discriminator passes (detached predictions, then one-hot ground
    truth).  The five loss figures of every iteration are appended to a
    timestamped text file under ``logs/``.  Besides the periodic and final
    checkpoints, a checkpoint is written whenever the segmentation loss
    improves on the best value seen so far, and ``delete_models`` is called
    for the previous best.
    """
    h, w = list(map(int, args.input_size.split(',')))
    input_size = (h, w)

    cudnn.enabled = True

    # create network
    model = Res_Deeplab(num_classes=args.num_classes)

    # load pretrained parameters
    # NOTE(review): torch.load unpickles the checkpoint; only restore from
    # trusted files.
    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)

    # only copy the params that exist in current model (caffe-like)
    new_params = model.state_dict().copy()
    for name, param in list(new_params.items()):
        if (name in saved_state_dict
                and param.size() == saved_state_dict[name].size()):
            new_params[name].copy_(saved_state_dict[name])
    model.load_state_dict(new_params)

    model.train()
    model.cuda(args.gpu)

    cudnn.benchmark = True

    # init D (fully convolutional discriminator)
    model_D = FCDiscriminator(num_classes=args.num_classes)
    if args.restore_from_D is not None:
        model_D.load_state_dict(torch.load(args.restore_from_D))
    model_D.train()
    model_D.cuda(args.gpu)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    if not os.path.exists('logs/'):
        os.makedirs('logs/')
    now_time = datetime.now().strftime('%Y-%m-%d-%H:%M:%S')
    log_file = 'logs/' + now_time + '.txt'  # per-iteration loss log

    train_dataset = VOCDataSet(args.data_dir,
                               args.data_list,
                               args.label_list,
                               crop_size=input_size,
                               scale=args.random_scale,
                               mirror=args.random_mirror,
                               mean=IMG_MEAN)
    train_dataset_size = len(train_dataset)

    train_gt_dataset = VOCGTDataSet(args.data_dir,
                                    args.data_list,
                                    args.label_list,
                                    crop_size=input_size,
                                    scale=args.random_scale,
                                    mirror=args.random_mirror,
                                    mean=IMG_MEAN)

    if args.partial_data is None:
        # use all of the data
        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      shuffle=True,
                                      num_workers=5,
                                      pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=5,
                                         pin_memory=True)
    else:
        # sample partial data
        partial_size = int(args.partial_data * train_dataset_size)

        if args.partial_id is not None:
            # Pickles must be read in binary mode on Python 3; the context
            # manager also closes the handle.
            # NOTE(review): unpickling executes arbitrary code; only load
            # id files you produced yourself.
            with open(args.partial_id, 'rb') as id_file:
                train_ids = pickle.load(id_file)
            print('loading train ids from {}'.format(args.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        # write train_ids to train_id.pkl
        with open(osp.join(args.snapshot_dir, 'train_id.pkl'),
                  'wb') as id_file:
            pickle.dump(train_ids, id_file)

        train_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = data.DataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      sampler=train_sampler,
                                      num_workers=3,
                                      pin_memory=True)
        trainloader_remain = data.DataLoader(train_dataset,
                                             batch_size=args.batch_size,
                                             sampler=train_remain_sampler,
                                             num_workers=3,
                                             pin_memory=True)
        trainloader_gt = data.DataLoader(train_gt_dataset,
                                         batch_size=args.batch_size,
                                         sampler=train_gt_sampler,
                                         num_workers=3,
                                         pin_memory=True)

        trainloader_remain_iter = enumerate(trainloader_remain)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # implement model.optim_parameters(args) to handle different models' lr setting

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()

    if version.parse(torch.__version__) >= version.parse('0.4.0'):
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear',
                             align_corners=True)
    else:
        interp = nn.Upsample(size=(input_size[1], input_size[0]),
                             mode='bilinear')

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    best_loss = 1
    best_epoch = 0

    # The context manager closes the loss log even if training raises.
    with open(log_file, 'w') as log_fh:
        for i_iter in range(args.num_steps):

            loss_seg_value = 0
            loss_adv_pred_value = 0
            loss_D_value = 0
            loss_semi_value = 0
            loss_semi_adv_value = 0

            optimizer.zero_grad()
            adjust_learning_rate(optimizer, i_iter)
            optimizer_D.zero_grad()
            adjust_learning_rate_D(optimizer_D, i_iter)

            for sub_i in range(args.iter_size):

                # Only set by the semi branch below; the D_remain step
                # checks for None so that it cannot hit undefined names
                # before that branch has run.
                pred_remain = None
                ignore_mask_remain = None

                # train G

                # don't accumulate grads in D
                for param in model_D.parameters():
                    param.requires_grad = False

                # do semi first
                if ((args.lambda_semi > 0 or args.lambda_semi_adv > 0)
                        and i_iter >= args.semi_start_adv):
                    # NOTE(review): the "remain" loader only exists when
                    # args.partial_data is set; otherwise NameError.
                    try:
                        _, batch = next(trainloader_remain_iter)
                    except StopIteration:
                        trainloader_remain_iter = enumerate(
                            trainloader_remain)
                        _, batch = next(trainloader_remain_iter)

                    # only access to img (unlabeled data)
                    images, _, _, _, _ = batch
                    images = Variable(images).cuda(args.gpu)

                    pred = interp(model(images))
                    # detached copy, reused by the discriminator step
                    pred_remain = pred.detach()

                    D_out = interp(model_D(F.softmax(pred)))
                    D_out_sigmoid = F.sigmoid(
                        D_out).data.cpu().numpy().squeeze(axis=1)

                    # All-False mask: nothing is ignored for the adversarial
                    # term.  Builtin ``bool`` replaces the ``np.bool`` alias
                    # that was removed from NumPy.
                    ignore_mask_remain = np.zeros(
                        D_out_sigmoid.shape).astype(bool)

                    loss_semi_adv = args.lambda_semi_adv * bce_loss(
                        D_out, make_D_label(gt_label, ignore_mask_remain))
                    loss_semi_adv = loss_semi_adv / args.iter_size

                    # NOTE(review): divides by lambda_semi_adv, which may be
                    # 0 when only lambda_semi is enabled — confirm intent.
                    loss_semi_adv_value += (
                        loss_semi_adv.data.cpu().numpy() /
                        args.lambda_semi_adv)

                    if args.lambda_semi <= 0 or i_iter < args.semi_start:
                        loss_semi_adv.backward()
                        loss_semi_value = 0
                    else:
                        # produce ignore mask: pixels the discriminator
                        # scores below the threshold mask_T are excluded
                        semi_ignore_mask = (D_out_sigmoid < args.mask_T)

                        # index of the maximum along axis 1
                        semi_gt = pred.data.cpu().numpy().argmax(axis=1)
                        semi_gt[semi_ignore_mask] = 255

                        # fraction of pixels that are kept (not ignored)
                        semi_ratio = 1.0 - float(
                            semi_ignore_mask.sum()) / semi_ignore_mask.size
                        print('semi ratio: {:.4f}'.format(semi_ratio))

                        if semi_ratio == 0.0:
                            # NOTE(review): loss_semi_adv is not
                            # back-propagated on this path — confirm that
                            # dropping it is intended.
                            loss_semi_value += 0
                        else:
                            semi_gt = torch.FloatTensor(semi_gt)

                            loss_semi = args.lambda_semi * loss_calc(
                                pred, semi_gt, args.gpu)
                            loss_semi = loss_semi / args.iter_size
                            loss_semi_value += (
                                loss_semi.data.cpu().numpy() /
                                args.lambda_semi)
                            loss_semi += loss_semi_adv
                            loss_semi.backward()
                else:
                    loss_semi = None
                    loss_semi_adv = None

                # train with source; restart the loader when exhausted
                try:
                    _, batch = next(trainloader_iter)
                except StopIteration:
                    trainloader_iter = enumerate(trainloader)
                    _, batch = next(trainloader_iter)

                images, labels, _, _, _ = batch  # labeled data
                images = Variable(images).cuda(args.gpu)
                ignore_mask = (labels.numpy() == 255)
                pred = interp(model(images))  # upsample to input size

                loss_seg = loss_calc(pred, labels, args.gpu)

                # discriminator map for the prediction
                D_out = interp(model_D(F.softmax(pred)))

                # The generator is trained to make D output the "gt" label.
                loss_adv_pred = bce_loss(
                    D_out, make_D_label(gt_label, ignore_mask))

                loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

                # proper normalization
                loss = loss / args.iter_size
                loss.backward()
                loss_seg_value += (loss_seg.data.cpu().numpy() /
                                   args.iter_size)
                loss_adv_pred_value += (loss_adv_pred.data.cpu().numpy() /
                                        args.iter_size)

                # train D

                # bring back requires_grad
                for param in model_D.parameters():
                    param.requires_grad = True

                # train with pred (detached so no gradient reaches G)
                pred = pred.detach()

                if args.D_remain and pred_remain is not None:
                    pred = torch.cat((pred, pred_remain), 0)
                    ignore_mask = np.concatenate(
                        (ignore_mask, ignore_mask_remain), axis=0)

                D_out = interp(model_D(F.softmax(pred)))
                loss_D = bce_loss(D_out,
                                  make_D_label(pred_label, ignore_mask))
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

                # train with gt
                # get gt labels
                try:
                    _, batch = next(trainloader_gt_iter)
                except StopIteration:
                    trainloader_gt_iter = enumerate(trainloader_gt)
                    _, batch = next(trainloader_gt_iter)

                _, labels_gt, _, _, _ = batch
                D_gt_v = Variable(one_hot(labels_gt)).cuda(args.gpu)
                ignore_mask_gt = (labels_gt.numpy() == 255)

                # feed the ground truth to the discriminator
                D_out = interp(model_D(D_gt_v))
                loss_D = bce_loss(D_out,
                                  make_D_label(gt_label, ignore_mask_gt))
                loss_D = loss_D / args.iter_size / 2
                loss_D.backward()
                loss_D_value += loss_D.data.cpu().numpy()

            optimizer.step()
            optimizer_D.step()

            print('exp = {}'.format(args.snapshot_dir))
            print('iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, '
                  'loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'.format(
                      i_iter, args.num_steps, loss_seg_value,
                      loss_adv_pred_value, loss_D_value, loss_semi_value,
                      loss_semi_adv_value))

            log_fh.write('{0} {1} {2} {3} {4}\n'.format(
                loss_seg_value, loss_adv_pred_value, loss_D_value,
                loss_semi_value, loss_semi_adv_value))

            if loss_seg_value < best_loss:
                # save the new best model, then ask delete_models to remove
                # the previous best
                torch.save(
                    model.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'VOC_epoch_{0}_seg_loss_{1}.pth'.format(
                            i_iter + 1, loss_seg_value)))
                torch.save(
                    model_D.state_dict(),
                    osp.join(
                        args.snapshot_dir,
                        'VOC_epoch_{0}_seg_loss_{1}_D.pth'.format(
                            i_iter + 1, loss_seg_value)))
                delete_models(best_epoch + 1, best_loss)
                best_loss = loss_seg_value
                best_epoch = i_iter

            if i_iter >= args.num_steps - 1:
                print('save model ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'VOC_' + str(args.num_steps) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(args.snapshot_dir,
                             'VOC_' + str(args.num_steps) + '_D.pth'))
                break

            if i_iter % args.save_pred_every == 0 and i_iter != 0:
                print('taking snapshot ...')
                torch.save(
                    model.state_dict(),
                    osp.join(args.snapshot_dir,
                             'VOC_' + str(i_iter) + '.pth'))
                torch.save(
                    model_D.state_dict(),
                    osp.join(args.snapshot_dir,
                             'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')
def main():
    """Adversarially train the segmentation network against a discriminator on CPU.

    Reads its configuration from the module-level ``args`` namespace
    (``input_size``, ``num_classes``, ``snapshot_dir``, ``learning_rate``,
    ``momentum``, ``weight_decay``, ``learning_rate_D``, ``num_steps``,
    ``iter_size``, ``lambda_adv_pred``, ``save_pred_every``) and the
    module-level ``start`` timestamp used for the final timing report.

    For every one of ``args.num_steps`` iterations it accumulates gradients
    over ``args.iter_size`` sub-iterations, each of which:

    1. updates the segmentation network ("G") with a segmentation loss plus
       a weighted adversarial loss, while the discriminator is frozen;
    2. updates the discriminator ("D") on the detached prediction
       (label ``pred_label``) and on one-hot ground truth (label
       ``gt_label``).

    Model and discriminator weights are written to ``args.snapshot_dir``
    every ``args.save_pred_every`` iterations and once more at the end.

    NOTE: restoring pretrained weights, CUDA placement and the
    semi-supervised branch are disabled in this variant; both networks are
    trained from their default initialisation on the CPU.
    """
    # parse input size
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)

    # create segmentation network (no pretrained weights are loaded here)
    model = DeepLab(num_classes=args.num_classes)
    model.train()
    model.cpu()

    # create discriminator network
    model_D = Discriminator(num_classes=args.num_classes)
    model_D.train()
    model_D.cpu()

    # MILESTONE 1
    print("Printing MODELS ...")
    print(model)
    print(model_D)

    # Create directory to save snapshots of the model
    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    # Load train data and ground truth labels.
    # The batch size is hard-coded to 5 here rather than read from args.
    train_dataset = MyCustomDataset()
    train_gt_dataset = MyCustomDataset()

    trainloader = data.DataLoader(train_dataset, batch_size=5, shuffle=True)
    trainloader_gt = data.DataLoader(train_gt_dataset, batch_size=5,
                                     shuffle=True)

    trainloader_iter = enumerate(trainloader)
    trainloader_gt_iter = enumerate(trainloader_gt)

    # MILESTONE 2
    print("Printing Loaders")
    print(trainloader_iter)
    print(trainloader_gt_iter)

    # optimizer for segmentation network
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=args.learning_rate_D,
                             betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    # MILESTONE 3
    print("Printing OPTIMIZERS ...")
    print(optimizer)
    print(optimizer_D)

    # loss / bilinear upsampling
    bce_loss = BCEWithLogitsLoss2d()
    # NOTE(review): the target size is passed as (input_size[1], input_size[0]),
    # i.e. (w, h). nn.Upsample expects (height, width), so this only matches
    # the parsed (h, w) for square inputs -- confirm before using non-square
    # sizes.
    interp = nn.Upsample(size=(input_size[1], input_size[0]),
                         mode='bilinear',
                         align_corners=True)

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    for i_iter in range(args.num_steps):

        loss_seg_value = 0
        loss_adv_pred_value = 0
        loss_D_value = 0
        # The semi-supervised branch is disabled, so these two stay at 0;
        # they are kept so the log line format below is unchanged.
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter)

        for sub_i in range(args.iter_size):

            # ---- train G ----

            # don't accumulate grads in D
            for param in model_D.parameters():
                param.requires_grad = False

            # train with source; restart the loader once it is exhausted.
            # Only StopIteration is caught so that real loading errors and
            # KeyboardInterrupt are not silently swallowed.
            try:
                _, batch = next(trainloader_iter)
            except StopIteration:
                trainloader_iter = enumerate(trainloader)
                _, batch = next(trainloader_iter)

            images, labels, _, _ = batch
            images = Variable(images).cpu()
            ignore_mask = (labels.numpy() == 255)

            # segmentation prediction
            pred = interp(model(images))

            # (spatial multi-class) cross entropy loss
            loss_seg = loss_calc(pred, labels)

            # discriminator prediction; softmax over dim 1 of the 4-D
            # prediction (the dim the implicit form would have picked)
            D_out = interp(model_D(F.softmax(pred, dim=1)))

            # adversarial loss: G is rewarded when D labels its prediction
            # with gt_label
            loss_adv_pred = bce_loss(D_out,
                                     make_D_label(gt_label, ignore_mask))

            # multi-task loss, lambda_adv_pred weights the adversarial term
            loss = loss_seg + args.lambda_adv_pred * loss_adv_pred

            # loss normalization over the accumulated sub-iterations
            loss = loss / args.iter_size

            # back propagation
            loss.backward()

            loss_seg_value += loss_seg.item() / args.iter_size
            loss_adv_pred_value += loss_adv_pred.item() / args.iter_size

            # ---- train D ----

            # bring back requires_grad
            for param in model_D.parameters():
                param.requires_grad = True

            # train with pred; detach so no gradient flows back into G
            pred = pred.detach()

            D_out = interp(model_D(F.softmax(pred, dim=1)))
            loss_D = bce_loss(D_out, make_D_label(pred_label, ignore_mask))
            # halved because D receives two updates (pred and gt) per step
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

            # train with gt: get gt labels, restarting the loader if needed
            try:
                _, batch = next(trainloader_gt_iter)
            except StopIteration:
                trainloader_gt_iter = enumerate(trainloader_gt)
                _, batch = next(trainloader_gt_iter)

            _, labels_gt, _, _ = batch
            D_gt_v = Variable(one_hot(labels_gt)).cpu()
            ignore_mask_gt = (labels_gt.numpy() == 255)

            D_out = interp(model_D(D_gt_v))
            loss_D = bce_loss(D_out, make_D_label(gt_label, ignore_mask_gt))
            loss_D = loss_D / args.iter_size / 2
            loss_D.backward()
            loss_D_value += loss_D.item()

        optimizer.step()
        optimizer_D.step()

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg = {2:.3f}, loss_adv_p = {3:.3f}, loss_D = {4:.3f}, loss_semi = {5:.3f}, loss_semi_adv = {6:.3f}'
            .format(i_iter, args.num_steps, loss_seg_value,
                    loss_adv_pred_value, loss_D_value, loss_semi_value,
                    loss_semi_adv_value))

        # final snapshot, then stop
        if i_iter >= args.num_steps - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir,
                         'VOC_' + str(args.num_steps) + '_D.pth'))
            break

        # periodic snapshot
        if i_iter % args.save_pred_every == 0 and i_iter != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '.pth'))
            torch.save(
                model_D.state_dict(),
                osp.join(args.snapshot_dir, 'VOC_' + str(i_iter) + '_D.pth'))

    end = timeit.default_timer()
    print(end - start, 'seconds')