def __init__(self, batch_norm_mode, depth, model_root_channel=8, img_size=256, batch_size=20, n_channel=1, n_class=2):
    """Assemble the graph: placeholders, network, dice loss, IoU metric, summaries.

    Args:
        batch_norm_mode: stored as ``self.batch_mode``.
        depth: stored as ``self.depth_n``.
        model_root_channel: stored as ``self.model_channel``.
        img_size: height and width of the ``X`` / ``Y`` placeholders.
        batch_size: stored as ``self.batch_size``.
        n_channel: channel count of the ``X`` placeholder.
        n_class: channel count of the ``Y`` placeholder.
    """
    # Scalars fed at run time.
    self.drop_rate = tf.placeholder(tf.float32)
    self.training = tf.placeholder(tf.bool)

    # Stored before neural_net() is called below
    # (presumably it reads them -- confirm against neural_net).
    self.batch_size = batch_size
    self.model_channel = model_root_channel
    self.batch_mode = batch_norm_mode
    self.depth_n = depth

    image_shape = [None, img_size, img_size, n_channel]
    label_shape = [None, img_size, img_size, n_class]
    self.X = tf.placeholder(tf.float32, image_shape, name='X')
    self.Y = tf.placeholder(tf.float32, label_shape, name='Y')

    self.logits = self.neural_net()

    # Split prediction and label along axis 3 into two one-channel tensors.
    # The [1, 1] split sizes mean exactly two channels are expected.
    class_probabilities = tf.nn.softmax(self.logits)
    predicted_channels = tf.split(class_probabilities, [1, 1], 3)
    self.foreground_predicted, self.background_predicted = predicted_channels
    truth_channels = tf.split(self.Y, [1, 1], 3)
    self.foreground_truth, self.background_truth = truth_channels

    with tf.name_scope('Loss'):
        # Dice loss is the active objective. Alternatives that were tried:
        #   cross entropy:
        #     tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        #         logits=self.logits, labels=self.Y))
        #   focal loss:
        #     utils.focal_loss(output=self.logits, target=self.Y,
        #                      use_class=True, gamma=2, smooth=1e-8)
        self.loss = utils.dice_loss(output=self.logits, target=self.Y)

    with tf.name_scope('Metrics'):
        # Exposed as "accuracy", but the value is the foreground mean IoU.
        self.accuracy = utils.mean_iou(self.foreground_predicted, self.foreground_truth)

    # TensorBoard scalars.
    tf.summary.scalar('loss', self.loss)
    tf.summary.scalar('accuracy', self.accuracy)
def main():
    """Smoke-test the U-Net: one forward/backward pass of the dice loss on random data.

    Builds the model through ``factory('unet')``, feeds it a random batch of
    eight 3x256x256 images with random binary masks, back-propagates the dice
    loss and prints the model and its logits.
    """
    from utils import dice_loss

    # Fall back to the CPU so the smoke test also runs without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = factory('unet').to(device).train()

    inputs = torch.randn(8, 3, 256, 256).to(device)
    # random_(to) samples from [0, to - 1]; the previous random_(1) therefore
    # produced an all-zero mask. random_(2) gives a real binary {0, 1} target.
    labels = torch.LongTensor(8, 256, 256).random_(2).float().to(device)

    # Call the module itself (not .forward) so registered hooks run.
    logits = model(inputs)
    loss = dice_loss(logits, labels)
    loss.backward()

    print(type(model))
    print(model)
    print('logits')
    print(logits)
def _set_requires_grad(net, flag):
    """Enable or disable gradient tracking for every parameter of ``net``."""
    for param in net.parameters():
        param.requires_grad = flag


def train(NetG, NetD, optimizerG, optimizerD, dataloader, epoch):
    """Run one adversarial training epoch and print the averaged losses.

    Each batch does a discriminator (critic) update followed by a generator
    update:

    * D step: maximise the mean absolute difference between NetD's response to
      the prediction-masked input and to the ground-truth-masked input, plus
      the gradient penalty from ``utils.calc_gradient_penalty``.
    * G step: minimise that same difference plus the dice loss of the
      sigmoid-activated generator output.

    Args:
        NetG: generator (segmentor) network; its output is passed through a sigmoid.
        NetD: discriminator (critic) network.
        optimizerG: optimizer for ``NetG``.
        optimizerD: optimizer for ``NetD``.
        dataloader: iterable of batches; ``data[0]`` is the input and
            ``data[1]`` the target mask.
        epoch: current epoch number (used for logging only).
    """
    total_dice = 0
    total_g_loss = 0
    total_g_loss_dice = 0
    total_g_loss_bce = 0
    total_d_loss = 0
    total_d_loss_penalty = 0

    NetG.train()
    NetD.train()

    for data in dataloader:
        # ---- train D ----
        optimizerD.zero_grad()
        NetD.zero_grad()
        _set_requires_grad(NetG, False)
        _set_requires_grad(NetD, True)

        images = data[0].float()
        target = data[1].float()
        if use_cuda:
            images = images.cuda()
            target = target.cuda()

        # Detach so the D step never back-propagates into the generator.
        output = torch.sigmoid(NetG(images)).detach()

        input_img = images.clone()
        output_masked = input_img * output
        result = NetD(output_masked)
        target_masked = input_img * target
        target_D = NetD(target_masked)

        # Negated: the critic maximises the distance between the two responses.
        loss_mac = -torch.mean(torch.abs(result - target_D))
        loss_mac.backward()

        # D net gradient penalty
        batch_size = target_masked.size(0)
        gradient_penalty = utils.calc_gradient_penalty(
            NetD, target_masked, output_masked, batch_size, use_cuda,
            images.shape)
        gradient_penalty.backward()
        optimizerD.step()

        # ---- train G ----
        optimizerG.zero_grad()
        NetG.zero_grad()
        _set_requires_grad(NetG, True)
        _set_requires_grad(NetD, False)

        output = torch.sigmoid(NetG(images))

        target_dice = target.view(-1).long()
        output_dice = output.view(-1)
        loss_dice = utils.dice_loss(output_dice, target_dice)

        output_masked = input_img * output
        result = NetD(output_masked)
        target_G = NetD(target_masked)
        loss_G = torch.mean(torch.abs(result - target_G))

        loss_G_joint = loss_G + loss_dice
        loss_G_joint.backward()
        optimizerG.step()

        # .item() replaces .data[0], which raises on 0-dim loss tensors.
        total_dice += 1 - loss_dice.item()
        total_g_loss += loss_G_joint.item()
        total_g_loss_dice += loss_dice.item()
        total_g_loss_bce += loss_G.item()
        total_d_loss += loss_mac.item()
        total_d_loss_penalty += gradient_penalty.item()

    # Leave both networks trainable for whoever runs next.
    _set_requires_grad(NetG, True)
    _set_requires_grad(NetD, True)

    # Guard against an empty dataloader (the averages are then all zero).
    size = max(len(dataloader), 1)
    epoch_dice = total_dice / size
    epoch_g_loss = total_g_loss / size
    epoch_g_loss_dice = total_g_loss_dice / size
    epoch_g_loss_bce = total_g_loss_bce / size
    epoch_d_loss = total_d_loss / size
    epoch_d_loss_penalty = total_d_loss_penalty / size

    print_format = [
        epoch, conf.epochs, epoch_dice * 100, epoch_g_loss,
        epoch_g_loss_dice, epoch_g_loss_bce, epoch_d_loss,
        epoch_d_loss_penalty
    ]
    print('===> Training step {}/{} \tepoch_dice: {:.5f}'
          '\tepoch_g_loss: {:.5f} \tepoch_g_loss_dice: {:.5f}'
          '\tepoch_g_loss_bce: {:.5f} \tepoch_d_loss: {:.5f}'
          '\tepoch_d_loss_penalty: {:.5f}'.format(*print_format))
def final_multiscale_roi_align(model, optimizer, scheduler, dataloaders, dataset_sizes, num_epochs=25, viz=False):
    """Train and validate the RoI-align segmentation model with a dice + NLL loss.

    The model is expected to emit log-softmax maps; they are turned into a
    probability map by ``get_prob_map28``, upsampled x8 and compared with the
    full-resolution label through the dice loss. The patch-level NLL loss is
    added with weights 0.7 (dice) / 0.3 (NLL). Training resumes from
    ``outputs/single_scale_roi_align/model/<model.name>.pth`` when it exists,
    and a checkpoint is written on completion, on Ctrl-C, and on failure.

    Design notes (planned steps):

    First, generate all the patches as 28x28 from the MRI scan, using anchor
    boxes; write a Transform for that.
      - For a scan, generate 10 RoIs.
      - The generator returns: MRI_scan224, MRI_label224, plus all anchor
        boxes of (MRI_scan28x28, MRI_label28x28).
      - If the sample has some lesion, return all the RoIs with that lesion;
        return just the (x1, y1, x2, y2) of the boxes in the 224x224 map.
      - If the sample has no lesion, return 10 RoIs of the no-lesion zone.
    Then perform these actions in the sub-level data transform:
      - Take the 224x224 tensor and the RoIs, and do the RoI align to
        generate these levels of feature maps.
      - view(-1, m, n) and randomize all the samples, for all (m, n) map
        levels.
      - Run a simple algorithm to get the class as 0 or 1 for every patch.

    Second, get RoI maps for the same 28x28 RoI from the feature maps of the
    CNN using RoI align, and by passing through the deconv nets.
      - The model() nn.Module has to perform all this.
      - It has to run deconv nets as torch.nn modules for these levels of
        patch dims to result in uniform 28x28 maps.
      - Concat all the 28x28 predicted masks from these feature levels, then
        apply small 3x3 convs and a 1x1 conv until it ends up here.
      - It has to return 28x28 predictions for all feature levels
        individually, plus the max-class-voting result from these
        predictions as one mask, plus the classification head.

    Third, frame the loss function with the classifier head and the
    segmentor head.
      - Train the classifier for all samples.
      - Run a simple algorithm to collect only those samples with non-zero
        lesion based on the patch classifier label.
      - Run a piecewise loss for every patch mask to prediction. Also,
        double it up with a secondary loss function.

    Args:
        model: network to train; must expose ``name`` and return log-softmax
            class maps.
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training epoch.
        dataloaders: mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, label224, label28)``.
        dataset_sizes: mapping with the sample count of each phase.
        num_epochs: number of epochs to run in this call.
        viz: when true, dump label/prediction images for the validation phase.
    """
    # out dirs
    base_dir = Path.cwd() / 'outputs' / 'single_scale_roi_align'
    output_tracking_dir = base_dir / 'output_tracking'
    logs_dir = base_dir / 'logs'
    model_dir = base_dir / 'model'
    for directory in (logs_dir, model_dir, output_tracking_dir):
        directory.mkdir(parents=True, exist_ok=True)

    model = model.to(device)
    upsampler = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)

    since = time.time()
    PATH = str(model_dir / (model.name + '.pth'))
    epo = 1
    # Last known epoch loss. Defined up front so every save_model call below
    # has a value even when no epoch finished.
    epoch_loss = float('inf')
    if Path(PATH).is_file():
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epo = checkpoint['epoch']
        loss = checkpoint['loss']
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch_loss = loss
        print('Resuming from epoch ' + str(epo) + ', LOSS: ', loss)

    best_model_wts = copy.deepcopy(model.state_dict())
    # A lower loss is better, so start from +inf and keep the minimum.
    best_loss = float('inf')
    # Equals num_epochs on a fresh run; keeps counting on a resumed one.
    last_epoch = epo + num_epochs - 1

    with open(str(logs_dir / 'train_logs'), 'a') as logs_ptr:
        for epoch in range(epo, epo + num_epochs):
            epoch_str = 'Epoch {}/{}'.format(epoch, last_epoch) + '\n\n'
            print(epoch_str)
            logs_ptr.write(epoch_str)
            print('-' * 10)
            try:
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode

                    running_softmax = 0.0
                    running_dice = 0.0

                    # Iterate over data.
                    for mini_batch, (inputs, label224, label28) in enumerate(dataloaders[phase]):
                        inputs = inputs.to(device)
                        # labels size is (batch_size, 1, 224, 224)
                        label28 = label28.to(device)
                        label224 = label224.to(device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in train
                        with torch.set_grad_enabled(phase == 'train'):
                            # shape of pred28 is (batch_size, 2, 28, 28)
                            log_softmax_outputs28 = model(inputs)
                            # squeeze(1) drops only the channel axis; a bare
                            # squeeze() would also drop a batch axis of size 1.
                            softmax_loss = F.nll_loss(
                                log_softmax_outputs28,
                                label28.round().squeeze(1).long(),
                                weight=class_weights)
                            softmax_outputs28 = torch.exp(log_softmax_outputs28)
                            torch_pred28_prob = get_prob_map28(softmax_outputs28)
                            torch_pred224_prob = upsampler(torch_pred28_prob)
                            # return format is (batch_size, 1, 224, 224)
                            # NOTE(review): round() has zero gradient, so the
                            # dice term likely does not drive the weights --
                            # confirm this is intended.
                            rounded_pred224_prob_for_dice = torch.round(torch_pred224_prob)
                            dice_l = dice_loss(input=rounded_pred224_prob_for_dice, target=label224)
                            total_loss = 0.7 * dice_l + 0.3 * softmax_loss

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                total_loss.backward()
                                optimizer.step()

                        if phase == 'train':
                            step_str = '{} Step {}- Loss: {:.4f}, Dice Loss: {:.4f}, Softmax Loss: {:.4f}'\
                                .format(phase, mini_batch + 1, total_loss.item(),
                                        dice_l.item(), softmax_loss.item())
                            print(step_str)
                            logs_ptr.write(step_str + '\n')

                        if phase == 'val' and viz:
                            # directory for this epoch's intermediate outputs
                            epoch_tracking_path = output_tracking_dir / str(epoch)
                            epoch_tracking_path.mkdir(parents=True, exist_ok=True)
                            for item in range(label28.size(0)):
                                image_name = str(mini_batch * label28.size(0) + item) + '.jpg'
                                # .cpu() first: .numpy() fails on device tensors.
                                actual_predicted(
                                    label224[item][0].cpu().numpy(),
                                    rounded_pred224_prob_for_dice[item][0].detach().cpu().numpy(),
                                    str(epoch_tracking_path / image_name))

                        # statistics
                        running_dice += dice_l.item() * inputs.size(0)
                        running_softmax += softmax_loss.item() * inputs.size(0)

                    # end of an epoch
                    epoch_dice_l = running_dice / dataset_sizes[phase]
                    epoch_softmax = running_softmax / dataset_sizes[phase]
                    epoch_loss = epoch_dice_l + epoch_softmax

                    if phase == 'train':
                        scheduler.step()

                    loss_str = '\n{} Epoch {}: TotalLoss: {:.4f} SoftmaxLoss: {:.4f} Dice Loss: {:.4f} \n'.format(
                        phase, epoch, epoch_loss, epoch_softmax, epoch_dice_l) + '\n'
                    print(loss_str)
                    logs_ptr.write(loss_str + '\n')

                    # deep copy the model when the validation loss improves
                    if phase == 'val' and epoch_loss < best_loss:
                        print('Val Dice better than Best Dice')
                        best_loss = epoch_loss
                        best_model_wts = copy.deepcopy(model.state_dict())
            except KeyboardInterrupt:
                # Manual stop: checkpoint and exit cleanly.
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise SystemExit(0)
            except Exception:
                # Real failure: checkpoint, then let the traceback surface.
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise
            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val DICE: {:4f}'.format(best_loss))

    # save model
    save_model(last_epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
def experiment3(model, optimizer, scheduler, dataloaders, dataset_sizes, num_epochs=25, viz=False):
    """Train and validate a full-resolution (224x224) segmentation model.

    The model is expected to emit log-softmax class maps. The loss is
    ``0.9 * dice(argmax prediction, label) + 0.1 * NLL``. Training resumes
    from ``outputs/experiment3/model/<model.name>.pth`` when it exists, and a
    checkpoint is written on completion, on Ctrl-C, and on failure.

    Args:
        model: network to train; must expose ``name``.
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training epoch.
        dataloaders: mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, label224, _)``.
        dataset_sizes: mapping with the sample count of each phase.
        num_epochs: number of epochs to run in this call.
        viz: when true, dump label/prediction images for the validation phase.
    """
    # out dirs
    base_dir = Path.cwd() / 'outputs' / 'experiment3'
    output_tracking_dir = base_dir / 'output_tracking'
    logs_dir = base_dir / 'logs'
    model_dir = base_dir / 'model'
    for directory in (logs_dir, model_dir, output_tracking_dir):
        directory.mkdir(parents=True, exist_ok=True)

    model = model.to(device)

    since = time.time()
    PATH = str(model_dir / (model.name + '.pth'))
    epo = 1
    # Last known epoch loss. Defined up front so every save_model call below
    # has a value even when no epoch finished.
    epoch_loss = float('inf')
    if Path(PATH).is_file():
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epo = checkpoint['epoch']
        loss = checkpoint['loss']
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch_loss = loss
        print('Resuming from epoch ' + str(epo) + ', LOSS: ', loss)

    best_model_wts = copy.deepcopy(model.state_dict())
    # A lower loss is better, so start from +inf and keep the minimum.
    best_loss = float('inf')
    # Equals num_epochs on a fresh run; keeps counting on a resumed one.
    last_epoch = epo + num_epochs - 1

    with open(str(logs_dir / 'train_logs'), 'a') as logs_ptr:
        for epoch in range(epo, epo + num_epochs):
            epoch_str = 'Epoch {}/{}'.format(epoch, last_epoch) + '\n\n'
            print(epoch_str)
            logs_ptr.write(epoch_str)
            print('-' * 10)
            try:
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode

                    running_softmax = 0.0
                    running_dice = 0.0

                    # Iterate over data.
                    for mini_batch, (inputs, label224, _) in enumerate(dataloaders[phase]):
                        inputs = inputs.to(device)
                        # labels size is (batch_size, 1, 224, 224)
                        label224 = label224.to(device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in train
                        with torch.set_grad_enabled(phase == 'train'):
                            # shape of pred224 is (batch_size, 2, 224, 224)
                            log_softmax_outputs224 = model(inputs)
                            # squeeze(1) drops only the channel axis; a bare
                            # squeeze() would also drop a batch axis of size 1.
                            softmax_loss = F.nll_loss(
                                log_softmax_outputs224,
                                label224.squeeze(1).long(),
                                weight=class_weights)
                            softmax_outputs224 = torch.exp(log_softmax_outputs224)
                            # keepdim=True keeps the class axis, now of size 1
                            _, pred224_argmax = torch.max(softmax_outputs224, dim=1, keepdim=True)
                            # Argmax indices carry no gradient, so only the
                            # NLL term back-propagates.
                            pred224_argmax = pred224_argmax.float()
                            dice_l = dice_loss(input=pred224_argmax, target=label224)
                            total_loss = 0.9 * dice_l + 0.1 * softmax_loss

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                total_loss.backward()
                                optimizer.step()

                        if phase == 'train':
                            step_str = '{} Step {}- Loss: {:.4f}, Dice Loss: {:.4f}, Softmax Loss: {:.4f}'\
                                .format(phase, mini_batch + 1, total_loss.item(),
                                        dice_l.item(), softmax_loss.item())
                            print(step_str)
                            logs_ptr.write(step_str + '\n')

                        if phase == 'val' and viz:
                            # directory for this epoch's intermediate outputs
                            epoch_tracking_path = output_tracking_dir / str(epoch)
                            epoch_tracking_path.mkdir(parents=True, exist_ok=True)
                            for item in range(label224.size(0)):
                                image_name = str(mini_batch * label224.size(0) + item) + '.jpg'
                                # .cpu() first: .numpy() fails on device tensors.
                                actual_predicted(
                                    label224[item][0].cpu().numpy(),
                                    pred224_argmax[item][0].detach().cpu().numpy(),
                                    str(epoch_tracking_path / image_name))

                        # statistics
                        running_dice += dice_l.item() * inputs.size(0)
                        running_softmax += softmax_loss.item() * inputs.size(0)

                    # end of an epoch
                    epoch_dice_l = running_dice / dataset_sizes[phase]
                    epoch_softmax = running_softmax / dataset_sizes[phase]
                    epoch_loss = epoch_dice_l + epoch_softmax

                    if phase == 'train':
                        scheduler.step()

                    loss_str = '\n{} Epoch {}: TotalLoss: {:.4f} SoftmaxLoss: {:.4f} Dice Loss: {:.4f} \n'.format(
                        phase, epoch, epoch_loss, epoch_softmax, epoch_dice_l) + '\n'
                    print(loss_str)
                    logs_ptr.write(loss_str + '\n')

                    # deep copy the model when the validation loss improves
                    if phase == 'val' and epoch_loss < best_loss:
                        print('Val Dice better than Best Dice')
                        best_loss = epoch_loss
                        best_model_wts = copy.deepcopy(model.state_dict())
            except KeyboardInterrupt:
                # Manual stop: checkpoint and exit cleanly.
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise SystemExit(0)
            except Exception:
                # Real failure: checkpoint, then let the traceback surface.
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise
            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val DICE: {:4f}'.format(best_loss))

    # save model
    save_model(last_epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
def experiment1(model, optimizer, scheduler, dataloaders, dataset_sizes, num_epochs=25, viz=False):
    """Train and validate the 28x28 patch model with a regression + dice loss.

    The model is expected to emit log-softmax maps; ``get_prob_map28`` turns
    them into a probability map that is compared with ``label28``. The loss is
    ``reg_loss + 0.5 * dice``, where ``reg_loss`` is the per-sample sum of
    ``-log(1 - |pred - label|)`` averaged over the batch and divided by 1000.
    Training resumes from ``outputs/experiment1/model/<model.name>.pth`` when
    it exists, and a checkpoint is written on completion, on Ctrl-C, and on
    failure.

    Args:
        model: network to train; must expose ``name``.
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training epoch.
        dataloaders: mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, label224, label28)``; ``label224`` is unused here.
        dataset_sizes: mapping with the sample count of each phase.
        num_epochs: number of epochs to run in this call.
        viz: when true, dump label/prediction images for the validation phase.
    """
    # model to CUDA
    model = model.to(device)

    # out dirs
    base_dir = Path.cwd() / 'outputs' / 'experiment1'
    output_tracking_dir = base_dir / 'output_tracking'
    logs_dir = base_dir / 'logs'
    model_dir = base_dir / 'model'
    for directory in (logs_dir, model_dir, output_tracking_dir):
        directory.mkdir(parents=True, exist_ok=True)

    since = time.time()
    PATH = str(model_dir / (model.name + '.pth'))
    epo = 1
    # Last known epoch loss. Defined up front so every save_model call below
    # has a value even when no epoch finished.
    epoch_loss = float('inf')
    if Path(PATH).is_file():
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epo = checkpoint['epoch']
        loss = checkpoint['loss']
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        epoch_loss = loss
        print('Resuming from epoch ' + str(epo) + ', LOSS: ', loss)

    best_model_wts = copy.deepcopy(model.state_dict())
    # A lower loss is better, so start from +inf and keep the minimum.
    best_loss = float('inf')
    # Equals num_epochs on a fresh run; keeps counting on a resumed one.
    last_epoch = epo + num_epochs - 1

    with open(str(logs_dir / 'train_logs'), 'a') as logs_ptr:
        for epoch in range(epo, epo + num_epochs):
            epoch_str = 'Epoch {}/{}'.format(epoch, last_epoch) + '\n\n'
            print(epoch_str)
            logs_ptr.write(epoch_str)
            print('-' * 10)
            try:
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode

                    running_reg = 0.0
                    running_dice = 0.0

                    # Iterate over data.
                    for mini_batch, (inputs, _, label28) in enumerate(dataloaders[phase]):
                        inputs = inputs.to(device)
                        label28 = label28.to(device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in train
                        with torch.set_grad_enabled(phase == 'train'):
                            # shape of pred28 is (batch_size, 2, 28, 28)
                            log_softmax_outputs28 = model(inputs)
                            softmax_outputs28 = torch.exp(log_softmax_outputs28)
                            output28_prob = get_prob_map28(softmax_outputs28)
                            # NOTE(review): this is log(0) = -inf whenever
                            # |pred - label| reaches exactly 1 -- consider
                            # clamping.
                            reg_loss = torch.mean(
                                torch.sum(-torch.log(1.0 - torch.abs(output28_prob - label28)),
                                          dim=[1, 2, 3])
                            ) / 1000.0
                            dice_l = dice_loss(input=torch.round(output28_prob),
                                               target=torch.round(label28))
                            total_loss = reg_loss + 0.5 * dice_l

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                total_loss.backward()
                                optimizer.step()

                        if phase == 'train':
                            step_str = '{} Step {}- Loss: {:.4f}, Dice Loss: {:.4f}, Reg Loss: {:.4f}'\
                                .format(phase, mini_batch + 1, total_loss.item(),
                                        dice_l.item(), reg_loss.item())
                            print(step_str)
                            logs_ptr.write(step_str + '\n')

                        if phase == 'val' and viz:
                            output28_prob = output28_prob.cpu()
                            label28 = label28.cpu()
                            epoch_tracking_path = output_tracking_dir / str(epoch)
                            epoch_tracking_path.mkdir(parents=True, exist_ok=True)
                            for item in range(label28.size(0)):
                                # blow the 28x28 maps up to 224x224 for viewing
                                expanded_output28_prob = expand_mask(
                                    [[0, 0, 224, 224]],
                                    output28_prob[item].detach().numpy(), (224, 224))
                                expanded_label28 = expand_mask(
                                    [[0, 0, 224, 224]],
                                    label28[item].detach().numpy(), (224, 224))
                                image_name = str(mini_batch * label28.size(0) + item) + '.jpg'
                                actual_predicted(expanded_label28[0],
                                                 expanded_output28_prob[0],
                                                 str(epoch_tracking_path / image_name))

                        # statistics
                        running_dice += dice_l.item() * inputs.size(0)
                        running_reg += reg_loss.item() * inputs.size(0)

                    # end of an epoch
                    epoch_dice_l = running_dice / dataset_sizes[phase]
                    epoch_reg_loss = running_reg / dataset_sizes[phase]
                    epoch_loss = epoch_dice_l + epoch_reg_loss

                    if phase == 'train':
                        scheduler.step()

                    loss_str = '\n{} Epoch {}: TotalLoss: {:.4f} RegLoss: {:.4f} Dice Loss: {:.4f} \n'.format(
                        phase, epoch, epoch_loss, epoch_reg_loss, epoch_dice_l) + '\n'
                    print(loss_str)
                    logs_ptr.write(loss_str + '\n')

                    # deep copy the model when the validation loss improves
                    if phase == 'val' and epoch_loss < best_loss:
                        print('Val Dice better than Best Dice')
                        best_loss = epoch_loss
                        best_model_wts = copy.deepcopy(model.state_dict())
            except KeyboardInterrupt:
                # Manual stop: checkpoint and exit cleanly.
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise SystemExit(0)
            except Exception:
                # Real failure: checkpoint, then let the traceback surface.
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise
            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val DICE: {:4f}'.format(best_loss))

    # save model
    save_model(last_epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
def main():
    """Parse CLI options, build the FPN model and data pipeline, and train.

    The dataset found under ``--data_path`` is split in file order into a
    training part (the first ``--val_split`` fraction) and a validation part
    (the remainder). On Ctrl-C the current weights are saved to
    ``<output_dir>/<model>_INTERRUPTED.pth`` before exiting.
    """
    parser = ArgumentParser()
    parser.add_argument('-d', '--data_path', dest='data_path', type=str, default=None, help='path to the data')
    parser.add_argument('-e', '--epochs', dest='epochs', default=20, type=int, help='number of epochs')
    parser.add_argument('-b', '--batch_size', dest='batch_size', default=40, type=int, help='batch size')
    parser.add_argument('-s', '--image_size', dest='image_size', default=256, type=int, help='input image size')
    parser.add_argument('-lr', '--learning_rate', dest='lr', default=0.0001, type=float, help='learning rate')
    parser.add_argument('-wd', '--weight_decay', dest='weight_decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('-lrs', '--learning_rate_step', dest='lr_step', default=10, type=int, help='learning rate step')
    parser.add_argument('-lrg', '--learning_rate_gamma', dest='lr_gamma', default=0.5, type=float, help='learning rate gamma')
    parser.add_argument(
        '-m',
        '--model',
        dest='model',
        default='fpn',
    )
    parser.add_argument('-w', '--weight_bce', default=0.5, type=float, help='weight BCE loss')
    parser.add_argument('-l', '--load', dest='load', default=False, help='load file model')
    # type=float: without it a value given on the command line stays a string
    # and breaks the split-size arithmetic below.
    parser.add_argument('-v', '--val_split', dest='val_split', default=0.7, type=float, help='train/val split')
    parser.add_argument('-o', '--output_dir', dest='output_dir', default='./output', help='dir to save log and models')
    args = parser.parse_args()

    # Fail early with a usage message rather than a TypeError deep in the
    # dataset construction.
    if args.data_path is None:
        parser.error('--data_path is required')

    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)

    # TODO: try a more novel arch and/or more lightweight blocks (mobilenet)
    # to enlarge the batch_size, e.g.
    # smp.FPN('mobilenet_v2', encoder_weights='imagenet', classes=2)
    net = smp.FPN('se_resnet50', encoder_weights='imagenet', classes=2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        net.load_state_dict(torch.load(args.load, map_location=device))
    logger.info('Model type: {}'.format(net.__class__.__name__))
    net.to(device)

    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    bce = nn.BCELoss()  # built once instead of on every criterion call

    def criterion(x, y):
        """Return the (weighted BCE, weighted dice) loss pair."""
        return (args.weight_bce * bce(x, y),
                (1. - args.weight_bce) * dice_loss(x, y))

    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step > 0 else None

    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        RandomRotate(),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True),
        ScaleToZeroOne(),
    ])
    val_transforms = Compose([
        Resize(size=(args.image_size, args.image_size)),
        ScaleToZeroOne(),
    ])

    train_dataset = DetectionDataset(args.data_path,
                                     os.path.join(args.data_path, 'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path, None, transforms=val_transforms)

    # Validation takes the tail of the file lists, training keeps the head.
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]

    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  num_workers=8, shuffle=True, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size,
                                num_workers=4, shuffle=False, drop_last=False)
    logger.info('Number of batches of train/val=%d/%d',
                len(train_dataloader), len(val_dataloader))

    try:
        train(net, optimizer, criterion, scheduler, train_dataloader,
              val_dataloader, logger=logger, args=args, device=device)
    except KeyboardInterrupt:
        torch.save(
            net.state_dict(),
            os.path.join(args.output_dir, f'{args.model}_INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
optimizer = optim.SGD(net_parallel.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.00004) iter_num = 0 while True: for i_batch, sampled_batch in enumerate(dataloader): volume_batch, label_batch = sampled_batch[ 'image'], sampled_batch['label'] volume_batch, label_batch = volume_batch.cuda( ), label_batch.cuda() output = net_parallel(volume_batch) output = F.sigmoid(output) loss = dice_loss(output, label_batch) optimizer.zero_grad() loss.backward() optimizer.step() iter_num = iter_num + 1 if iter_num % 10 == 0: print('iteration %d : loss : %f' % (iter_num, loss.item())) if iter_num % 5000 == 0: torch.save( net.state_dict(), os.path.join( snapshot_path, snapshot_prefix + '_iteration_' + str(iter_num) + '.pth')) if iter_num >= max_iterations:
def forward(self, batch_input, task=None):
    """Encode six camera views, decode a map, and collect losses and metrics.

    Two decoding paths exist, selected by ``self.model_type``:

    * ``"det"`` in the type: the fused features go straight to the decoder.
    * otherwise (variational): the fused features are projected to
      ``(mu, logvar)``, a latent is sampled and decoded, and a KL term is
      added to the loss.

    Args:
        batch_input: dict with ``"image"`` of shape ``(bs, 6, ...)`` and,
            depending on the configuration, ``"road"`` and/or ``"sem_map"``.
        task: unused.

    Returns:
        dict with ``"loss"`` plus, when a map is generated, ``"road_map"`` or
        ``"sem_map"``, ``"recon_loss"``, ``"ts_road_map"``, ``"ts"`` and (in
        the variational path) ``"KLD_loss"``. When ``self.detect_objects`` is
        set, the dict is whatever ``self.obj_detection_model`` returns.
    """
    batch_output = {}
    self.stage = "finetune"
    views = batch_input["image"]
    bs = views.size(0)
    self.batch_size = bs

    # Run the backbone on all views at once, then restore the view axis.
    final_features = self.image_network(views.flatten(0, 1))
    _, c, h, w = final_features.shape
    views = final_features.view(bs, 6, c, h, w)
    batch_output["loss"] = 0

    gen_any_map = (self.gen_roadmap or self.gen_semantic_map
                   or self.gen_object_map)
    if gen_any_map or (self.detect_objects
                       and "decoder" in self.blobs_strategy):
        fusion = self.fuse(views)

    if gen_any_map:
        if "det" in self.model_type:
            mapped_image = self.decoder_network(fusion)

            if self.gen_roadmap:
                batch_output["road_map"] = torch.sigmoid(mapped_image)
            else:
                batch_output["sem_map"] = F.softmax(mapped_image, dim=1)

            # NOTE(review): this branch mixes self.gen_roadmap and
            # self.args.gen_road_map -- confirm they always agree.
            if self.loss_type == "dice":
                if self.args.gen_road_map:
                    batch_output["recon_loss"] = dice_loss(
                        batch_input["road"].type(torch.LongTensor),
                        mapped_image)
                else:
                    batch_output["recon_loss"] = dice_loss(
                        batch_input["sem_map"].max(dim=1)[1].type(
                            torch.LongTensor), mapped_image)
            elif self.loss_type == 'bce':
                if self.gen_roadmap:
                    batch_output["recon_loss"] = self.criterion(
                        mapped_image, batch_input["road"])
                else:
                    batch_output["recon_loss"] = self.criterion(
                        mapped_image, batch_input["sem_map"].max(dim=1)[1])
            else:
                if self.args.gen_road_map:
                    batch_output["recon_loss"] = self.criterion(
                        batch_output["road_map"], batch_input["road"])
                else:
                    batch_output["recon_loss"] = self.criterion(
                        batch_output["sem_map"], batch_input["sem_map"])

            if self.gen_roadmap:
                batch_output["ts_road_map"] = compute_ts_road_map(
                    batch_output["road_map"], batch_input["road"])
            else:
                # per-pixel class agreement between prediction and label
                batch_output["ts_road_map"] = (
                    batch_output["sem_map"].max(dim=1)[1]
                    == batch_input["sem_map"].max(dim=1)[1]).float().mean()
            batch_output["ts"] = batch_output["ts_road_map"]
            batch_output["loss"] += batch_output["recon_loss"]
        else:
            if self.conv_fuse:
                fusion = self.avg_pool_refine(
                    self.avg_pool(fusion).view(-1, self.d_model))

            # One projection yields both Gaussian parameters.
            mu_logvar = self.z_project(fusion).view(bs, 2, self.latent_dim)
            mu = mu_logvar[:, 0, :]
            logvar = mu_logvar[:, 1, :]

            z = self.reparameterize(mu, logvar)
            z = self.z_refine(z)
            z = self.z_reshape(z).view(bs, 32, 16, 16)
            generated_image = self.decoder_network(z)

            # Apply the activations functionally: nn.Sigmoid / nn.Softmax are
            # module classes, so calling them with a tensor raised TypeError.
            if self.gen_roadmap:
                batch_output["road_map"] = torch.sigmoid(generated_image)
            else:
                batch_output["sem_map"] = F.softmax(generated_image, dim=1)

            if self.loss_type == "dice":
                if self.gen_roadmap:
                    batch_output["recon_loss"] = dice_loss(
                        batch_input["road"], batch_output["road_map"])
                else:
                    # No "road_map" entry exists here; use the same call as
                    # the deterministic path.
                    # NOTE(review): confirm dice_loss expects (labels, logits).
                    batch_output["recon_loss"] = dice_loss(
                        batch_input["sem_map"].max(dim=1)[1].type(
                            torch.LongTensor), generated_image)
            elif self.loss_type == 'bce':
                if self.gen_roadmap:
                    batch_output["recon_loss"] = self.criterion(
                        generated_image, batch_input["road"])
                else:
                    batch_output["recon_loss"] = self.criterion(
                        generated_image,
                        batch_input["sem_map"].max(dim=1)[1])
            else:
                if self.gen_roadmap:
                    batch_output["recon_loss"] = self.criterion(
                        batch_output["road_map"], batch_input["road"])
                else:
                    # Semantic maps are stored under "sem_map".
                    batch_output["recon_loss"] = self.criterion(
                        batch_output["sem_map"], batch_input["sem_map"])

            if self.gen_roadmap:
                batch_output["ts_road_map"] = compute_ts_road_map(
                    batch_output["road_map"], batch_input["road"])
            else:
                # Read "sem_map": "road_map" is never written in this case.
                batch_output["ts_road_map"] = (
                    batch_output["sem_map"].max(dim=1)[1]
                    == batch_input["sem_map"].max(dim=1)[1]).float().mean()

            # KL divergence of N(mu, exp(logvar)) from the unit Gaussian.
            batch_output["KLD_loss"] = -0.5 * torch.sum(
                1 + logvar - mu.pow(2) - logvar.exp())
            batch_output["ts"] = batch_output["ts_road_map"]
            batch_output["loss"] += (batch_output["recon_loss"]
                                     + batch_output["KLD_loss"])

    if self.detect_objects:
        if "decoder" in self.blobs_strategy:
            if "var" in self.model_type:
                # NOTE(review): z only exists if the variational map branch
                # ran above -- confirm the configuration guarantees that.
                batch_output = self.obj_detection_model(
                    z, batch_input, batch_output)
            else:
                batch_output = self.obj_detection_model(
                    fusion, batch_input, batch_output, fusion)
        else:
            batch_output = self.obj_detection_model(
                batch_input["image"], batch_input, batch_output)

    return batch_output
def forward(self, batch_input, task="none"):
    """Generate a road/semantic map from six views and score it.

    Only the supervised generator loss (``"GSupLoss"``) is computed; the
    discriminator losses that once accompanied it are disabled, so
    ``"loss"`` is returned as the constant 0.

    Args:
        batch_input: dict with ``"image"`` of shape ``(bs, 6, ...)`` and
            ``"road"`` (road-map mode) or ``"sem_map"`` (semantic mode).
        task: unused.

    Returns:
        dict with keys ``"loss"``, ``"road_map"``, ``"ts_road_map"``,
        ``"ts"`` and ``"GSupLoss"``.
    """
    cfg = self.args
    batch_output = {}
    self.stage = "finetune"

    images = batch_input["image"]
    device = images.device
    bs = images.size(0)
    self.batch_size = bs

    # Encode every view in one pass, then restore the (bs, 6, ...) layout.
    gen_latent_features = self.image_network(images.flatten(0, 1))
    _, c, h, w = gen_latent_features.shape
    stacked_views = gen_latent_features.view(bs, 6, c, h, w)
    batch_output["loss"] = 0

    if "cond" in cfg.finetune_obj:
        # Conditional variant: fuse the views together with a random latent.
        z = torch.randn(bs, cfg.latent_dim).to(device)
        fusion = self.fuse(stacked_views, z)
    else:
        fusion = self.fuse(stacked_views)

    gen_image = self.decoder_network(fusion)

    road_mode = cfg.gen_road_map
    if road_mode:
        batch_output["road_map"] = F.sigmoid(gen_image)
        batch_output["ts_road_map"] = compute_ts_road_map(
            batch_output["road_map"], batch_input["road"])
    else:
        # The semantic output is also stored under the "road_map" key.
        class_probs = F.softmax(gen_image, dim=1)
        batch_output["road_map"] = class_probs
        predicted_classes = class_probs.max(dim=1)[1]
        expected_classes = batch_input["sem_map"].max(dim=1)[1]
        batch_output["ts_road_map"] = (
            predicted_classes == expected_classes).float().mean()
    batch_output["ts"] = batch_output["ts_road_map"]

    # Supervised generator loss on the raw decoder output.
    if road_mode:
        target = batch_input["road"]
    else:
        target = batch_input["sem_map"].max(dim=1)[1]

    if cfg.road_map_loss == "dice":
        batch_output["GSupLoss"] = dice_loss(
            target.type(torch.LongTensor), gen_image)
    else:
        batch_output["GSupLoss"] = self.criterion(gen_image, target)

    return batch_output
def main():
    """Parse CLI options, build the model and data pipeline, and run training.

    If training is interrupted with ``KeyboardInterrupt`` the current weights
    are saved as ``INTERRUPTED.pth`` in the output directory and the process
    exits with status 0.
    """
    parser = ArgumentParser()
    # Required: the path is joined with 'train_mask.json' below, so a missing
    # value could only ever end in a TypeError.
    parser.add_argument('-d', '--data_path', dest='data_path', type=str, required=True, help='path to the data')
    parser.add_argument('-e', '--epochs', dest='epochs', default=20, type=int, help='number of epochs')
    parser.add_argument('-b', '--batch_size', dest='batch_size', default=40, type=int, help='batch size')
    parser.add_argument('-s', '--image_size', dest='image_size', default=256, type=int, help='input image size')
    parser.add_argument('-lr', '--learning_rate', dest='lr', default=0.0001, type=float, help='learning rate')
    parser.add_argument('-wd', '--weight_decay', dest='weight_decay', default=5e-4, type=float, help='weight decay')
    parser.add_argument('-lrs', '--learning_rate_step', dest='lr_step', default=10, type=int, help='learning rate step')
    parser.add_argument('-lrg', '--learning_rate_gamma', dest='lr_gamma', default=0.5, type=float,
                        help='learning rate gamma')
    parser.add_argument('-m', '--model', dest='model', default='unet', choices=('unet', ))
    parser.add_argument('-w', '--weight_bce', default=0.5, type=float, help='weight BCE loss')
    parser.add_argument('-l', '--load', dest='load', default=False, help='load file model')
    # type=float is essential: without it a value given on the command line
    # stays a string and the split-size arithmetic below breaks.
    parser.add_argument('-v', '--val_split', dest='val_split', default=0.8, type=float, help='train/val split')
    parser.add_argument('-o', '--output_dir', dest='output_dir', default='/tmp/logs/',
                        help='dir to save log and models')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)

    # TODO: use a more novel arch and/or more lightweight blocks (mobilenet) to enlarge the batch_size
    # TODO: img_size=256 is rather mediocre, try to optimize network for at least 512
    net = UNet()
    logger.info('Model type: %s', net.__class__.__name__)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        net.load_state_dict(torch.load(args.load, map_location=device))
    net.to(device)
    # net = nn.DataParallel(net)

    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # TODO: loss experimentation, fight class imbalance, there're many ways you can tackle this challenge
    bce_loss = nn.BCELoss()  # built once instead of on every criterion call

    def criterion(prediction, target):
        """Return the weighted (BCE, dice) loss terms as a 2-tuple."""
        return (args.weight_bce * bce_loss(prediction, target),
                (1. - args.weight_bce) * dice_loss(prediction, target))

    # TODO: you can always try on plateau scheduler as a default option
    if args.lr_step > 0:
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma)
    else:
        scheduler = None

    # dataset
    # TODO: to work on transformations a lot, look at albumentations package for inspiration
    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True),
    ])
    # TODO: don't forget to work class imbalance and data cleansing
    val_transforms = Resize(size=(args.image_size, args.image_size))

    train_dataset = DetectionDataset(args.data_path, os.path.join(args.data_path, 'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path, None, transforms=val_transforms)

    # split dataset into train/val, don't try to do this at home ;)
    # Despite its name, val_split is the fraction of samples kept for training;
    # the remainder becomes the validation set.
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]

    # TODO: always work with the data: cleaning, sampling
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=8,
                                  shuffle=True, drop_last=True)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, num_workers=4,
                                shuffle=False, drop_last=False)
    logger.info('Length of train/val=%d/%d', len(train_dataset), len(val_dataset))
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader), len(val_dataloader))

    try:
        train(net, optimizer, criterion, scheduler, train_dataloader, val_dataloader,
              logger=logger, args=args, device=device)
    except KeyboardInterrupt:
        # Keep the progress made so far before exiting cleanly.
        torch.save(net.state_dict(), os.path.join(args.output_dir, 'INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)