def train_single(model, batch, loader_val, optimizer, old_loss, old_f1, device, out_dir, n_steps=5):
    """Fit ``model`` on a single batch and report how validation scores moved.

    The batch is used for ``n_steps`` optimisation steps with the Dice loss,
    the model is then scored on ``loader_val`` via ``get_loss_score``, and the
    model's prediction for the same batch is written with ``torch.save``.

    Args:
        model: network to update in place.
        batch: 3-tuple; the second and third items are the inputs and labels
            (the first item is ignored).
        loader_val: validation loader forwarded to ``get_loss_score``.
        optimizer: optimizer bound to ``model``'s parameters.
        old_loss: validation loss before this call.
        old_f1: validation F1 before this call.
        device: device the batch is moved to.
        out_dir: destination handed to ``torch.save`` for the prediction.
        n_steps: number of optimisation steps on the batch (default 5, the
            value that used to be hard-coded).

    Returns:
        ``(out_dir, loss_improvement, f1_improvement)`` where the improvements
        are ``old_loss - new_loss`` and ``new_f1 - old_f1`` as Python numbers.
        Both differences must be tensors, since ``.cpu().item()`` is called on
        them.
    """
    # The batch does not change between steps, so move it to the device once.
    _, inputs, labels = batch
    inputs = inputs.to(device)
    labels = labels.to(device)

    for _ in range(n_steps):
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = dice_loss(outputs, labels)
            loss.backward()
            optimizer.step()

    new_loss, new_f1 = get_loss_score(model, loader_val, device)

    # Predict in eval mode and without autograd so the saved tensor carries no
    # graph and does not depend on the mode the model was left in; the mode
    # found here is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            pred_batch = model(inputs)
    finally:
        model.train(was_training)
    torch.save(pred_batch, out_dir)

    return out_dir, (old_loss - new_loss).cpu().item(), (new_f1 - old_f1).cpu().item()
def train(teacher, optimizer, train_loader):
    """Train ``teacher`` for one pass over ``train_loader`` using the Dice loss.

    Args:
        teacher: network to train; switched to training mode and moved to the
            GPU when one is available.
        optimizer: optimizer bound to ``teacher``'s parameters.
        train_loader: iterable yielding ``(img, gt)`` pairs.

    Returns:
        The mean of the per-batch Dice losses (also printed). Earlier versions
        returned ``None``; callers that ignore the result are unaffected.
    """
    print(' --- teacher training')
    # Query CUDA once so the model and the data always end up on the same
    # device; the model used to be moved to the GPU unconditionally.
    use_cuda = torch.cuda.is_available()
    teacher.train()
    if use_cuda:
        teacher.cuda()

    batch_losses = []
    for i, (img, gt) in enumerate(train_loader):
        print('i', i)
        if use_cuda:
            img, gt = img.cuda(), gt.cuda()

        # Both prediction and target are clamped to [0, 1] before the loss.
        output = teacher(img).clamp(min=0, max=1)
        gt = gt.clamp(min=0, max=1)
        loss = dice_loss(output, gt)
        batch_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    mean_dice = np.mean(batch_losses)
    print("Average loss over this epoch:\n\tDice:{}".format(mean_dice))
    return mean_dice
def train(images, labels):
    """Run one optimisation step on a batch and return ``(loss, metric)``.

    Uses the module-level ``model``, ``optimizer``, ``dice_loss`` and
    ``metrics`` objects.
    """
    with tf.GradientTape() as grad_tape:
        predictions = model(images, training=True)
        batch_loss = dice_loss(labels, predictions)
        batch_metric = metrics(labels, predictions)

    weights = model.trainable_variables
    gradients = grad_tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return batch_loss, batch_metric
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Return a weighted BCE + Dice loss and accumulate batch statistics.

    Args:
        pred: raw logits; they go to the logits-based BCE and are passed
            through a sigmoid before the Dice loss.
        target: ground-truth tensor; ``target.size(0)`` is used as the batch
            size when weighting the accumulated statistics.
        metrics: mapping updated in place under the keys ``'bce'``, ``'dice'``,
            ``'iou'`` and ``'loss'``.
        bce_weight: weight of the BCE term; the Dice term gets
            ``1 - bce_weight``.

    Returns:
        The combined loss tensor.
    """
    bce = torch.nn.functional.binary_cross_entropy_with_logits(pred, target)

    probabilities = torch.sigmoid(pred)
    dice = dice_loss(probabilities, target)

    # IoU is computed on NumPy copies that are detached from the graph.
    binary_mask = normalise_mask(probabilities.detach().cpu().numpy())
    iou = intersection_over_union(target.detach().cpu().numpy(), binary_mask)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    n_samples = target.size(0)
    batch_stats = (
        ('bce', bce.data.cpu().numpy()),
        ('dice', dice.data.cpu().numpy()),
        ('iou', iou),
        ('loss', loss.data.cpu().numpy()),
    )
    for key, value in batch_stats:
        metrics[key] += value * n_samples

    return loss
def get_loss_score(model, loader, device):
    """Return the sample-weighted mean Dice loss and F1 of ``model`` on ``loader``.

    The model is evaluated in eval mode with autograd disabled, and the mode it
    had on entry is restored before returning. Previously the saved flag was
    never used and the model was always left in eval mode.

    Args:
        model: network to score.
        loader: iterable yielding 3-tuples; the second and third items are the
            inputs and labels.
        device: device the tensors are moved to.

    Returns:
        ``(mean_loss, mean_f1)``, each the per-batch value weighted by batch
        size and divided by the total number of samples.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()

    loss_sum = 0
    acc_sum = 0
    num_pts = 0
    try:
        with torch.no_grad():
            for _, inputs, labels in loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)

                batch_size = inputs.size(0)
                loss_sum += dice_loss(outputs, labels) * batch_size
                acc_sum += calc_f1(outputs, labels) * batch_size
                num_pts += batch_size
    finally:
        # Restore the caller's mode even if scoring fails part-way through.
        model.train(was_training)

    return loss_sum / num_pts, acc_sum / num_pts
def evaluate(teacher, val_loader):
    """Compute and print the mean Dice loss of ``teacher`` on ``val_loader``.

    Args:
        teacher: network to evaluate; switched to eval mode and moved to the
            GPU when one is available.
        val_loader: iterable yielding ``(img, gt)`` pairs.

    Returns:
        The mean of the per-batch Dice losses. Earlier versions returned
        ``None``; callers that ignore the result are unaffected.
    """
    # Query CUDA once so the model and the data always end up on the same
    # device; the model used to be moved to the GPU unconditionally.
    use_cuda = torch.cuda.is_available()
    teacher.eval()
    if use_cuda:
        teacher.cuda()

    batch_losses = []
    with torch.no_grad():
        for img, gt in val_loader:
            if use_cuda:
                img, gt = img.cuda(), gt.cuda()

            # Both prediction and target are clamped to [0, 1] before the loss.
            output = teacher(img).clamp(min=0, max=1)
            gt = gt.clamp(min=0, max=1)
            batch_losses.append(dice_loss(output, gt).item())

    mean_dice = np.mean(batch_losses)
    print('Eval metrics:\n\tAverage Dice loss:{}'.format(mean_dice))
    return mean_dice
def evaluate_kd(student, val_loader):
    """Compute, print and return the mean Dice loss of ``student``.

    Args:
        student: network to evaluate; switched to eval mode and moved to the
            GPU when one is available.
        val_loader: iterable yielding ``(img, gt)`` pairs.

    Returns:
        The mean of the per-batch Dice losses.
    """
    print('-------Evaluate student-------')
    # Query CUDA once so the model and the data always end up on the same
    # device; the model used to be moved to the GPU unconditionally.
    use_cuda = torch.cuda.is_available()
    student.eval()
    if use_cuda:
        student.cuda()

    loss_summ = []
    with torch.no_grad():
        for img, gt in val_loader:
            if use_cuda:
                img, gt = img.cuda(), gt.cuda()

            # Only the prediction is clamped here; the target is used as-is.
            output = student(img).clamp(min=0, max=1)
            loss_summ.append(dice_loss(output, gt).item())

    mean_loss = np.mean(loss_summ)
    print('- Eval metrics:\n\tAverage Dice loss:{}'.format(mean_loss))
    return mean_loss
def train_student(student, teacher_outputs, optimizer, train_loader):
    """Train ``student`` for one epoch with the distillation loss.

    Called once per epoch. Every ``summary_steps`` batches (a module-level
    setting) the Dice loss and the distillation loss are recorded; their means
    are printed at the end.

    Args:
        student: network to train; switched to training mode and moved to the
            GPU when one is available.
        teacher_outputs: indexable collection; item ``i`` is the teacher's
            output for the ``i``-th batch yielded by ``train_loader``, so the
            loader has to yield batches in the matching order.
        optimizer: optimizer bound to ``student``'s parameters.
        train_loader: iterable yielding ``(img, gt)`` pairs.
    """
    print('-------Train student-------')
    # Query CUDA once so the model and the data always end up on the same
    # device; the model used to be moved to the GPU unconditionally.
    use_cuda = torch.cuda.is_available()
    student.train()
    if use_cuda:
        student.cuda()

    summ = []
    for i, (img, gt) in enumerate(train_loader):
        teacher_output = teacher_outputs[i]
        if use_cuda:
            img, gt = img.cuda(), gt.cuda()
            teacher_output = teacher_output.cuda()

        output = student(img)

        # TODO: loss is wrong (note carried over from the original author)
        loss = loss_fn_kd(output, teacher_output, gt)

        # Clear previous gradients, back-propagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % summary_steps == 0:
            # Reporting only: no graph is needed, and ``.item()`` already
            # yields a Python number, so no explicit move to the CPU is needed.
            with torch.no_grad():
                metric = dice_loss(output, gt)
            summ.append({'metric': metric.item(), 'loss': loss.item()})

    # NOTE: ``metric`` holds Dice *loss* values despite the variable name below.
    mean_dice_coeff = np.mean([x['metric'] for x in summ])
    mean_loss = np.mean([x['loss'] for x in summ])
    print('- Train metrics:\n' + '\tMetric:{}\n\tLoss:{}'.format(mean_dice_coeff, mean_loss))
return np.median(freq) / freq weights = compute_class_weights(Class) # Define training functions print "Defining and compiling training functions" prediction = lasagne.layers.get_output(simple_net_output[0]) # deb_pred = theano.function([input_var], prediction) # Loss function #loss = weighted_crossentropy(prediction, target_var, weight_vector) loss = dice_loss(prediction, target_var) loss = loss.mean() # Add regularization if weight_decay > 0: weightsl2 = regularize_network_params(simple_net_output, lasagne.regularization.l2) loss += weight_decay * weightsl2 # Add penalty to enforce the same number of transitions: if penalty_transitions > 0: true_prediction = T.reshape(target_var, (-1, 200)) prediction_reshape = T.reshape(prediction, (-1, 200, 2)) penalty_loss = abs(prediction_reshape[:, :, 1] - true_prediction).sum(axis=1) loss += penalty_transitions * penalty_loss.mean()
def train_model(model, dataloaders, policy_learner, optimizer, scheduler, num_epochs, device, writer, n_images=None):
    """Train ``model`` with the Dice loss, validating after every epoch.

    Args:
        model: network to train in place.
        dataloaders: mapping; only ``dataloaders['val']`` is used here.
        policy_learner: callable invoked at the start of each epoch; its return
            value is used as that epoch's training loader.
        optimizer: optimizer bound to ``model``'s parameters.
        scheduler: learning-rate scheduler, or a falsy value to disable it.
        num_epochs: number of epochs to run.
        device: device the batches are moved to.
        writer: object exposing ``add_scalar(tag, value, step)``
            (presumably a TensorBoard ``SummaryWriter`` - confirm with caller).
        n_images: optional ``{'train': int, 'val': int}`` running sample
            counters used as the x-axis for per-batch scalars; created when
            omitted and updated in place otherwise.

    Returns:
        ``(model, n_images)``.
    """
    loader = {'val': dataloaders['val']}
    best_loss = 1e10
    if n_images is None:
        n_images = {'train': 0, 'val': 0}

    for epoch in range(num_epochs):
        loader['train'] = policy_learner()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        since = time.time()

        # Each epoch has a training and a validation phase.
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                # NOTE(review): the scheduler is stepped before the epoch's
                # optimizer steps; recent PyTorch versions expect the opposite
                # order. Left unchanged to keep the learning-rate schedule.
                if scheduler:
                    scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()
            else:
                model.eval()

            metrics = defaultdict(float)
            epoch_samples = 0

            batches = tqdm(enumerate(loader[phase]), total=len(loader[phase]))
            for enum_id, (_idxs, inputs, labels) in batches:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = dice_loss(outputs, labels)
                    acc_f1 = calc_f1(outputs, labels)

                    if is_train:
                        loss.backward()
                        plot_grad_flow(epoch, enum_id, model.named_parameters())
                        optimizer.step()

                batch_size = inputs.size(0)
                epoch_samples += batch_size
                n_images[phase] += batch_size

                # Detach before accumulating: summing the raw loss tensor would
                # keep every batch's autograd history alive for the whole epoch.
                loss_value = loss.detach()
                writer.add_scalar(f'{phase}/loss', loss_value.cpu().numpy(), n_images[phase])
                metrics['loss'] += loss_value * batch_size
                metrics['f1'] += acc_f1 * batch_size

            print_metrics(writer, metrics, epoch_samples, phase)

            epoch_loss = metrics['loss'] / epoch_samples
            writer.add_scalar(f'{phase}/epoch_loss', epoch_loss, epoch)
            epoch_f1 = metrics['f1'] / epoch_samples
            writer.add_scalar(f'{phase}/epoch_F1', epoch_f1, epoch)

            # Keep the best validation loss so the final report is meaningful;
            # it used to stay at its 1e10 initial value.
            if phase == 'val' and float(epoch_loss) < best_loss:
                best_loss = float(epoch_loss)

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))
    return model, n_images
def general_loss(student_output, gt):
    """Return the Dice loss between ``student_output`` and ``gt``."""
    # Original author's note: use torch.nn.CrossEntropyLoss()
    return dice_loss(student_output, gt)
def train(dataset, segm_net, learning_rate=0.005, lr_anneal=1.0,
          weight_decay=1e-4, num_epochs=500, max_patience=100,
          optimizer='rmsprop', training_loss=['squared_error'],
          batch_size=[10, 1, 1], ae_h=False, dae_dict_updates={},
          data_augmentation={}, savepath=None, loadpath=None, resume=False,
          train_from_0_255=False, lmb=1, full_im_ft=False):
    """Train a DAE on top of a segmentation network (Python 2, Theano/Lasagne).

    The function builds the segmentation network (``fcn``) and the DAE,
    compiles Theano training/validation functions, then runs the epoch loop
    with early stopping on the validation cost. Only the DAE parameters are
    passed to the optimizer; ``fcn`` is used solely to produce inputs.

    Args:
        dataset: dataset name; the option check below only accepts 'camvid'.
        segm_net: 'fcn8' or 'densenet' (other values raise).
        learning_rate: initial learning rate, stored in a Theano shared var.
        lr_anneal: factor the learning rate is multiplied by after each epoch.
        weight_decay: only forwarded to ``build_experiment_name`` here.
        num_epochs: maximum number of epochs.
        max_patience: number of non-improving epochs that ends training.
        optimizer: 'rmsprop' or 'adam'.
        training_loss: collection of loss names; the names checked are
            'crossentropy', 'dice', 'squared_error' and 'squared_error_h'.
        batch_size: forwarded to ``load_data``.
        ae_h: enables the extra ``h`` / ``h_hat`` reconstruction loss term.
        dae_dict_updates: overrides merged into the default DAE settings.
        data_augmentation: forwarded to ``load_data``; ``['crop_size']`` is
            read in the option check.
        savepath: base directory for outputs (required).
        loadpath: base directory for loading/copying weights.
        resume: forwarded as ``load_weights`` to the DAE builders.
        train_from_0_255: forwarded to ``load_data`` as ``return_0_255``.
        lmb: weight of the squared-error loss term.
        full_im_ft: appends '_ft' to the experiment name and, for the
            'standard' DAE, also triggers weight loading.

    Raises:
        ValueError: if ``savepath`` is None, or for unknown/unsupported
            network, DAE kind or optimizer options.
        NotImplementedError: for ``segm_net == 'fcn_fcresnet'``.

    NOTE(review): several defaults are mutable objects (lists/dicts); they are
    not mutated in the code visible here, but this is fragile.
    """
    #
    # Update DAE parameters
    #
    dae_dict = {'kind': 'fcn8',
                'dropout': 0.0,
                'skip': True,
                'unpool_type': 'standard',
                'n_filters': 64,
                'conv_before_pool': 1,
                'additional_pool': 0,
                'concat_h': ['input'],
                'noise': 0.0,
                'from_gt': True,
                'temperature': 1.0,
                'path_weights': '',
                'layer': 'probs_dimshuffle',
                'exp_name': '',
                'bn': 0}
    dae_dict.update(dae_dict_updates)

    #
    # Prepare load/save directories
    #
    exp_name = build_experiment_name(segm_net,
                                     training_loss=training_loss,
                                     data_aug=bool(data_augmentation),
                                     learning_rate=learning_rate,
                                     lr_anneal=lr_anneal,
                                     weight_decay=weight_decay,
                                     optimizer=optimizer,
                                     ae_h=ae_h,
                                     **dae_dict)
    if savepath is None:
        raise ValueError('A saving directory must be specified')

    # NOTE(review): ``loadpath`` defaults to None but is passed straight to
    # os.path.join, which fails on None; callers must always supply it.
    # ``loadpath_init`` uses the experiment name *without* the '_ft' suffix.
    loadpath_init = os.path.join(loadpath, dataset, exp_name)
    exp_name += '_ft' if full_im_ft else ''
    loadpath = os.path.join(loadpath, dataset, exp_name)
    savepath = os.path.join(savepath, dataset, exp_name)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The following folder already exists {}. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    # Dump every local defined so far (arguments and derived paths) as config.
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')  # tensor for input image batch
    input_mask_var = T.tensor4('input_mask_var')  # tensor for segmentation batch (input dae)
    # NOTE(review): list multiplication repeats the SAME symbolic variable
    # object; with more than one entry in concat_h all hidden inputs would
    # share one variable - confirm this is intended.
    input_concat_h_vars = [T.tensor4()] * len(dae_dict['concat_h'])  # tensor for hidden repr batch (input dae)
    target_var = T.tensor4('target_var')  # tensor for target batch

    # learning_rate = learning_rate*0.1 if full_im_ft else learning_rate
    # learning_rate = 0.01
    print learning_rate
    lr = theano.shared(np.float32(learning_rate), 'learning_rate')

    #
    # Build dataset iterator
    #
    train_iter, val_iter, _ = load_data(dataset, data_augmentation,
                                        one_hot=True,
                                        batch_size=batch_size,
                                        return_0_255=train_from_0_255,
                                        )

    n_batches_train = train_iter.nbatches
    n_batches_val = val_iter.nbatches
    n_classes = train_iter.non_void_nclasses
    void_labels = train_iter.void_labels
    nb_in_channels = train_iter.data_shape[0]
    # ``void`` is used below as the exclusive upper bound when slicing the
    # channel axis of the targets (``[:, :void, :, :]``).
    void = n_classes if any(void_labels) else n_classes+1

    #
    # Build networks
    #
    # Check that model and dataset get along
    print 'Checking options'
    assert (segm_net == 'fcn8' and dataset == 'camvid') or \
        (segm_net == 'densenet' and dataset == 'camvid')
    # NOTE(review): with the default ``data_augmentation={}`` this lookup
    # raises KeyError before the assertion is even evaluated.
    assert (data_augmentation['crop_size'] == None and full_im_ft) or not full_im_ft

    # Build segmentation network
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        # Request the hidden layers to concatenate, plus the prediction layer
        # when the DAE input does not come from the ground truth.
        layer_out = copy.copy(dae_dict['concat_h'])
        layer_out += [copy.copy(dae_dict['layer'])] if not dae_dict['from_gt'] else []
        fcn = buildFCN8(nb_in_channels, input_x_var, n_classes=n_classes,
                        void_labels=void_labels,
                        path_weights=WEIGHTS_PATH+dataset+'/fcn8_model.npz',
                        load_weights=True,
                        layer=layer_out)
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                               n_classes=n_classes,
                               layer=dae_dict['concat_h'], from_gt=dae_dict['from_gt'])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError

    # Build DAE network
    print 'Building DAE network'
    if ae_h and dae_dict['kind'] != 'standard':
        raise ValueError('Plug&Play not implemented for ' + dae_dict['kind'])
    if ae_h and 'pool' not in dae_dict['concat_h'][-1]:
        raise ValueError('Plug&Play version needs concat_h to be different than input')
    ae_h = ae_h and 'pool' in dae_dict['concat_h'][-1]

    if dae_dict['kind'] == 'standard':
        nb_features_to_concat=fcn[0].output_shape[1]
        dae = buildDAE(input_concat_h_vars, input_mask_var, n_classes,
                       nb_features_to_concat=nb_features_to_concat,
                       padding=padding, trainable=True,
                       void_labels=void_labels,
                       load_weights=resume or full_im_ft,
                       path_weights=loadpath_init,
                       model_name='dae_model_best.npz',
                       out_nonlin=softmax,
                       concat_h=dae_dict['concat_h'],
                       noise=dae_dict['noise'],
                       n_filters=dae_dict['n_filters'],
                       conv_before_pool=dae_dict['conv_before_pool'],
                       additional_pool=dae_dict['additional_pool'],
                       dropout=dae_dict['dropout'],
                       skip=dae_dict['skip'],
                       unpool_type=dae_dict['unpool_type'],
                       bn=dae_dict['bn'],
                       ae_h=ae_h)
    elif dae_dict['kind'] == 'fcn8':
        # Weights are looked up in the parent directory of ``loadpath_init``
        # joined with dae_dict['path_weights'].
        dae = buildFCN8_DAE(input_concat_h_vars, input_mask_var, n_classes,
                            nb_in_channels=n_classes, trainable=True,
                            load_weights=resume, pretrained=True, pascal=True,
                            concat_h=dae_dict['concat_h'],
                            noise=dae_dict['noise'],
                            dropout=dae_dict['dropout'],
                            path_weights=os.path.join('/'.join(loadpath_init.split('/')[:-1]),
                                                      dae_dict['path_weights']),
                            model_name='dae_model_best.npz')
    elif dae_dict['kind'] == 'contextmod':
        dae = buildDAE_contextmod(input_concat_h_vars, input_mask_var, n_classes,
                                  path_weights=loadpath_init,
                                  model_name='dae_model.npz',
                                  trainable=True, load_weights=resume,
                                  out_nonlin=softmax, noise=dae_dict['noise'],
                                  concat_h=dae_dict['concat_h'])
    else:
        raise ValueError('Unknown dae kind')

    #
    # Define and compile theano functions
    #
    # training functions
    print "Defining and compiling training functions"

    # fcn prediction
    fcn_prediction = lasagne.layers.get_output(fcn, deterministic=True,
                                               batch_norm_use_averages=False)

    # select prediction layers (pooling and upsampling layers); the last DAE
    # layer is always included so that index -1 is the final prediction
    dae_all_lays = lasagne.layers.get_all_layers(dae)
    if dae_dict['kind'] != 'contextmod':
        dae_lays = [l for l in dae_all_lays
                    if isinstance(l, Pool2DLayer) or
                    isinstance(l, CroppingLayer) or
                    isinstance(l, ElemwiseSumLayer) or
                    l == dae_all_lays[-1]]
        # dae_lays = dae_lays[::2]
    else:
        dae_lays = [l for l in dae_all_lays if isinstance(l, DilatedConv2DLayer) or l == dae_all_lays[-1]]

    if ae_h:
        # Positions of the layers named 'h_to_recon' and 'h_hat' in dae_lays.
        h_ae_idx = [i for i, el in enumerate(dae_lays) if el.name == 'h_to_recon'][0]
        h_hat_idx = [i for i, el in enumerate(dae_lays) if el.name == 'h_hat'][0]

    # predictions (training graph: not deterministic)
    dae_prediction_all = lasagne.layers.get_output(dae_lays,
                                                   batch_norm_use_averages=False)
    dae_prediction = dae_prediction_all[-1]
    dae_prediction_h = dae_prediction_all[:-1]

    # predictions (validation graph: deterministic)
    test_dae_prediction_all = lasagne.layers.get_output(dae_lays,
                                                        deterministic=True,
                                                        batch_norm_use_averages=False)
    test_dae_prediction = test_dae_prediction_all[-1]
    test_dae_prediction_h = test_dae_prediction_all[:-1]

    # fetch h and h_hat if needed
    if ae_h:
        h = dae_prediction_all[h_ae_idx]
        h_hat = dae_prediction_all[h_hat_idx]
        h_test = test_dae_prediction_all[h_ae_idx]
        h_hat_test = test_dae_prediction_all[h_hat_idx]

    # loss (accumulated term by term below)
    loss = 0
    test_loss = 0

    # Convert DAE prediction to 2D: move axis 1 last, then flatten the first
    # three axes into one
    dae_prediction_2D = dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = dae_prediction_2D.shape
    dae_prediction_2D = dae_prediction_2D.reshape((T.prod(sh[:3]), sh[3]))

    test_dae_prediction_2D = test_dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = test_dae_prediction_2D.shape
    test_dae_prediction_2D = test_dae_prediction_2D.reshape((T.prod(sh[:3]), sh[3]))

    # Convert target to 2D (same transformation as the predictions)
    target_var_2D = target_var.dimshuffle((0, 2, 3, 1))
    sh = target_var_2D.shape
    target_var_2D = target_var_2D.reshape((T.prod(sh[:3]), sh[3]))

    if 'crossentropy' in training_loss:
        # Compute loss
        loss += crossentropy(dae_prediction_2D, target_var_2D, void_labels,
                             one_hot=True)
        test_loss += crossentropy(test_dae_prediction_2D, target_var_2D,
                                  void_labels, one_hot=True)

    if 'dice' in training_loss:
        loss += dice_loss(dae_prediction, target_var, void_labels)
        test_loss += dice_loss(test_dae_prediction, target_var, void_labels)

    # The validation MSE is always computed because val_fn reports it, even
    # when 'squared_error' is not part of the training loss.
    test_mse_loss = squared_error(test_dae_prediction, target_var, void)
    if 'squared_error' in training_loss:
        mse_loss = squared_error(dae_prediction, target_var, void)
        loss += lmb*mse_loss
        test_loss += lmb*test_mse_loss

    # Add intermediate losses
    if 'squared_error_h' in training_loss:
        # extract input layers and create dictionary: the first and last input
        # layers are fed the target, the ones in between the hidden inputs
        dae_input_lays = [l for l in dae_all_lays if isinstance(l, InputLayer)]
        inputs = {dae_input_lays[0]: target_var[:, :void, :, :], dae_input_lays[-1]:target_var[:, :void, :, :]}
        for idx, val in enumerate(input_concat_h_vars):
            inputs[dae_input_lays[idx+1]] = val

        test_dae_prediction_all_gt = lasagne.layers.get_output(dae_lays,
                                                               inputs=inputs,
                                                               deterministic=True,
                                                               batch_norm_use_averages=False)
        test_dae_prediction_h_gt = test_dae_prediction_all_gt[:-1]

        loss += squared_error_h(dae_prediction_h, test_dae_prediction_h_gt)
        test_loss += squared_error_h(test_dae_prediction_h, test_dae_prediction_h_gt)

    # compute jaccard
    # NOTE(review): ``jacc`` (training graph) is not used by any compiled
    # function below; only ``test_jacc`` is.
    jacc = jaccard(dae_prediction_2D, target_var_2D, n_classes, one_hot=True)
    test_jacc = jaccard(test_dae_prediction_2D, target_var_2D, n_classes, one_hot=True)

    # if reconstructing h add the corresponding loss terms
    if ae_h:
        loss += squared_error_L(h, h_hat).mean()
        test_loss += squared_error_L(h_test, h_hat_test).mean()

    # network parameters (DAE only)
    params = lasagne.layers.get_all_params(dae, trainable=True)

    # optimizer
    if optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss, params, learning_rate=lr)
    elif optimizer == 'adam':
        updates = lasagne.updates.adam(loss, params, learning_rate=lr)
    else:
        raise ValueError('Unknown optimizer')

    # functions
    train_fn = theano.function(input_concat_h_vars + [input_mask_var, target_var], loss, updates=updates)

    fcn_fn = theano.function([input_x_var], fcn_prediction)

    val_fn = theano.function(input_concat_h_vars + [input_mask_var, target_var], [test_loss, test_jacc, test_mse_loss])

    # Per-epoch histories; indexed by epoch number further down.
    err_train = []
    err_valid = []
    jacc_val_arr = []
    mse_val_arr = []
    patience = 0

    #
    # Train
    #
    # Training main loop
    print "Start training"
    for epoch in range(num_epochs):
        # Single epoch training and validation
        start_time = time.time()
        cost_train_tot = 0

        # Train
        for i in range(n_batches_train):
            # Get minibatch
            X_train_batch, L_train_batch = train_iter.next()
            L_train_batch = L_train_batch.astype(_FLOATX)

            #####uncomment if you want to control the feasability of pooling####
            # max_n_possible_pool = np.floor(np.log2(np.array(X_train_batch.shape[2:]).min()))
            # # check if we don't ask for more poolings than possible
            # assert n_pool+additional_pool < max_n_possible_pool
            ####################################################################

            # h prediction
            H_pred_batch = fcn_fn(X_train_batch)
            if dae_dict['from_gt']:
                # DAE input mask taken from the ground truth
                Y_pred_batch = L_train_batch[:, :void, :, :]
            else:
                # DAE input mask is the last fcn output; the rest are hidden reps
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Training step
            cost_train = train_fn(*(H_pred_batch + [Y_pred_batch, L_train_batch]))
            cost_train_tot += cost_train

        err_train += [cost_train_tot / n_batches_train]

        # Validation
        cost_val_tot = 0
        jacc_val_tot = 0
        mse_val_tot = 0
        for i in range(n_batches_val):
            # Get minibatch
            X_val_batch, L_val_batch = val_iter.next()
            L_val_batch = L_val_batch.astype(_FLOATX)

            # h prediction
            H_pred_batch = fcn_fn(X_val_batch)
            if dae_dict['from_gt']:
                Y_pred_batch = L_val_batch[:, :void, :, :]
            else:
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Validation step
            cost_val, jacc_val, mse_val = val_fn(*(H_pred_batch + [Y_pred_batch, L_val_batch]))
            cost_val_tot += cost_val
            jacc_val_tot += jacc_val
            mse_val_tot += mse_val

        err_valid += [cost_val_tot / n_batches_val]
        # Row 0 is divided by row 1 of the accumulated jaccard output, then the
        # result is averaged (presumably per-class intersection / union - TODO
        # confirm against ``jaccard``).
        jacc_val_arr += [np.mean(jacc_val_tot[0, :] / jacc_val_tot[1, :])]
        mse_val_arr += [mse_val_tot / n_batches_val]

        out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f," + \
                  " jacc val %f, mse val % f took %f s"
        out_str = out_str % (epoch, err_train[epoch],
                             err_valid[epoch],
                             jacc_val_arr[epoch],
                             mse_val_arr[epoch],
                             time.time() - start_time)
        print out_str

        with open(os.path.join(savepath, "output.log"), "a") as f:
            f.write(out_str + "\n")

        # update learning rate
        lr.set_value(float(lr.get_value() * lr_anneal))

        # Early stopping and saving stuff
        # NOTE(review): epoch 0 only records the best values; no
        # 'dae_model_best.npz' is written unless a later epoch improves.
        if epoch == 0:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
        elif epoch > 0 and err_valid[epoch] < best_err_val:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
            patience = 0
            np.savez(os.path.join(savepath, 'dae_model_best.npz'),
                     *lasagne.layers.get_all_param_values(dae))
            np.savez(os.path.join(savepath, 'dae_errors_best.npz'),
                     err_train, err_valid, jacc_val_arr, mse_val_arr)
        else:
            patience += 1
            np.savez(os.path.join(savepath, 'dae_model_last.npz'),
                     *lasagne.layers.get_all_param_values(dae))
            np.savez(os.path.join(savepath, 'dae_errors_last.npz'),
                     err_train, err_valid, jacc_val_arr, mse_val_arr)

        # Finish training if patience has expired or max number of epochs
        # reached
        if patience == max_patience or epoch == num_epochs - 1:
            # Copy files to loadpath
            if savepath != loadpath:
                print('Copying model and other training files to {}'.format(
                    loadpath))
                copy_tree(savepath, loadpath)
            # End
            print(' Training Done !')
            return
def validate(images, labels):
    """Score one batch in inference mode and return ``(loss, metric)``.

    Uses the module-level ``model``, ``dice_loss`` and ``metrics`` objects.
    """
    predictions = model(images, training=False)
    return dice_loss(labels, predictions), metrics(labels, predictions)
def _run_epoch(model, data, criterion, optimizer=None):
    """Run one pass over ``data`` and return per-batch ``(losses, ious, jaccs)``.

    When ``optimizer`` is given the weights are updated after every batch;
    otherwise the pass runs with autograd disabled. The caller is responsible
    for putting ``model`` in the matching train/eval mode.
    """
    training = optimizer is not None
    losses, ious, jaccs = [], [], []

    with torch.set_grad_enabled(training):
        # Wrapping the loader itself lets tqdm see its length.
        for x, y in tqdm(data):
            if training:
                optimizer.zero_grad()

            x = x.cuda()
            y = y.cuda()
            y_real = y.view(-1).float()
            y_pred = model(x)
            y_pred_probs = torch.sigmoid(y_pred).view(-1)

            # BCE works on the raw logits, Dice on the sigmoid probabilities.
            loss = (0.75 * criterion(y_pred.view(-1), y_real)
                    + 0.25 * dice_loss(y_pred_probs, y_real))
            iou_ = iou(y_pred_probs.float(), y_real.byte())
            jacc_ = jaccard(y_pred_probs.float(), y_real)

            ious.append(iou_.item())
            losses.append(loss.item())
            jaccs.append(jacc_.item())

            if training:
                loss.backward()
                optimizer.step()

    return losses, ious, jaccs


def train(
    epochs: int,
    models_dir: Path,
    x_cities: List[CityData],
    y_city: List[CityData],
    mask_dir: Path,
):
    """Train a ``UNet11`` on ``x_cities`` and validate on ``y_city`` each epoch.

    The loss is ``0.75 * BCE-with-logits + 0.25 * Dice``. After every epoch the
    mean validation loss drives a ``ReduceLROnPlateau`` scheduler, and the
    model is saved through ``save_model`` whenever that loss improves.

    Args:
        epochs: number of epochs to run.
        models_dir: directory under which the model is saved, in a location
            named after ``y_city[0].name``.
        x_cities: training cities.
        y_city: validation cities; the first entry names the saved model.
        mask_dir: forwarded to the dataset constructors.
    """
    model = UNet11().cuda()
    optimizer = Adam(model.parameters(), lr=3e-4)
    scheduler = ReduceLROnPlateau(optimizer, patience=4, factor=0.25)
    min_loss = float("inf")
    criterion = nn.BCEWithLogitsLoss()

    train_data = DataLoader(TrainDataset(x_cities, mask_dir),
                            batch_size=4,
                            num_workers=4,
                            shuffle=True)
    test_data = DataLoader(TestDataset(y_city, mask_dir),
                           batch_size=6,
                           num_workers=4)

    for epoch in range(epochs):
        print(f'Epoch {epoch}, lr {optimizer.param_groups[0]["lr"]}')
        print(" Training")

        model.train()
        losses, ious, jaccs = _run_epoch(model, train_data, criterion, optimizer)
        print(
            f"Loss: {np.mean(losses)}, IOU: {np.mean(ious)}, jacc: {np.mean(jaccs)}"
        )

        model.eval()
        losses, ious, jaccs = _run_epoch(model, test_data, criterion)
        test_loss = np.mean(losses)
        print(
            f"Loss: {np.mean(losses)}, IOU: {np.mean(ious)}, jacc: {np.mean(jaccs)}"
        )

        scheduler.step(test_loss)
        if test_loss < min_loss:
            min_loss = test_loss
            save_model(model, epoch, models_dir / y_city[0].name)