Example no. 1
def train_single(model, batch, loader_val, optimizer, old_loss, old_f1, device,
                 out_dir):
    for _ in range(5):
        _, inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            loss = dice_loss(outputs, labels)

            # backward + optimize only if in training phase
            loss.backward()
            optimizer.step()

    new_loss, new_f1 = get_loss_score(model, loader_val, device)
    _, inputs, _ = batch
    inputs = inputs.to(device)
    pred_batch = model(inputs)
    torch.save(pred_batch, out_dir)
    return out_dir, (old_loss - new_loss).cpu().item(), (new_f1 -
                                                         old_f1).cpu().item()
Example no. 2
def train(teacher, optimizer, train_loader):
    print(' --- teacher training')
    teacher.train().cuda()
    criterion = nn.BCEWithLogitsLoss()
    ll = []
    for i, (img, gt) in enumerate(train_loader):
        print('i', i)
        if torch.cuda.is_available():
            img, gt = img.cuda(), gt.cuda()

        img, gt = Variable(img), Variable(gt)

        output = teacher(img)
        output = output.clamp(min=0, max=1)
        gt = gt.clamp(min=0, max=1)
        loss = dice_loss(output, gt)
        ll.append(loss.item())

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()

    mean_dice = np.mean(ll)

    print("Average loss over this epoch:\n\tDice:{}".format(mean_dice))
Example no. 3
    def train(images, labels):
        with tf.GradientTape() as tape:
            output = model(images, training=True)
            loss = dice_loss(labels, output)
            metric = metrics(labels, output)

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        return loss, metric
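
The step above is a nested function that closes over model, optimizer, and metrics from its enclosing training routine. Callers typically wrap it in tf.function and invoke it once per batch; a minimal usage sketch, where train_ds and num_epochs are hypothetical names for the tf.data pipeline and the epoch budget:

import tensorflow as tf

# Optional: compile the step defined above into a graph for speed.
train_step = tf.function(train)

for epoch in range(num_epochs):          # hypothetical epoch budget
    for images, labels in train_ds:      # hypothetical tf.data.Dataset of (image, mask) batches
        loss, metric = train_step(images, labels)
    print(f'epoch {epoch}: loss={float(loss):.4f} metric={float(metric):.4f}')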
Example no. 4
def calc_loss(pred, target, metrics, bce_weight=0.5):
    bce = torch.nn.functional.binary_cross_entropy_with_logits(pred, target)

    pred = torch.sigmoid(pred)
    dice = dice_loss(pred, target)

    pred_binary = normalise_mask(pred.detach().cpu().numpy())
    iou = intersection_over_union(target.detach().cpu().numpy(), pred_binary)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    metrics['bce'] += bce.data.cpu().numpy() * target.size(0)
    metrics['dice'] += dice.data.cpu().numpy() * target.size(0)
    metrics['iou'] += iou * target.size(0)
    metrics['loss'] += loss.data.cpu().numpy() * target.size(0)

    return loss
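
These examples call dice_loss without including its definition. A minimal soft-Dice sketch that matches the two-argument, probability-input usage above (the smooth constant and the mean reduction over the batch are assumptions, not taken from the original code):

import torch

def dice_loss(pred, target, smooth=1e-6):
    # pred and target: tensors of identical shape holding probabilities / binary masks.
    # Flatten everything except the batch dimension and compute soft Dice per sample.
    pred = pred.contiguous().view(pred.size(0), -1)
    target = target.contiguous().view(target.size(0), -1)
    intersection = (pred * target).sum(dim=1)
    dice = (2.0 * intersection + smooth) / (pred.sum(dim=1) + target.sum(dim=1) + smooth)
    # Dice is a similarity in [0, 1]; the loss is its complement, averaged over the batch.
    return 1.0 - dice.mean()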
Example no. 5
def get_loss_score(model, loader, device):
    was_training = model.training
    model.eval()
    loss_sum = 0
    acc_sum = 0
    num_pts = 0
    for _, inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            loss = dice_loss(outputs, labels)
            acc = calc_f1(outputs, labels)
        loss_sum += loss * inputs.size(0)
        acc_sum += acc * inputs.size(0)
        num_pts += inputs.size(0)
    if was_training:
        model.train()
    return loss_sum / num_pts, acc_sum / num_pts
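
get_loss_score also relies on a calc_f1 helper that is not shown. A plausible sketch, assuming raw model outputs that are passed through a sigmoid and thresholded at 0.5:

import torch

def calc_f1(outputs, labels, threshold=0.5, eps=1e-8):
    # Assumed helper: binarise the predictions and compute a global F1 score.
    preds = (torch.sigmoid(outputs) > threshold).float().view(-1)
    labels = labels.float().view(-1)
    tp = (preds * labels).sum()
    precision = tp / (preds.sum() + eps)
    recall = tp / (labels.sum() + eps)
    return 2 * precision * recall / (precision + recall + eps)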
Example no. 6
def evaluate(teacher, val_loader):
    teacher.eval().cuda()

    criterion = nn.BCEWithLogitsLoss()
    ll = []
    with torch.no_grad():
        for i, (img, gt) in enumerate(val_loader):
            if torch.cuda.is_available():
                img, gt = img.cuda(), gt.cuda()
            img, gt = Variable(img), Variable(gt)

            output = teacher(img)
            output = output.clamp(min=0, max=1)
            gt = gt.clamp(min=0, max=1)
            loss = dice_loss(output, gt)
            ll.append(loss.item())

    mean_dice = np.mean(ll)
    print('Eval metrics:\n\tAverage Dice loss:{}'.format(mean_dice))
Example no. 7
def evaluate_kd(student, val_loader):
    print('-------Evaluate student-------')
    student.eval().cuda()

    #criterion = torch.nn.BCEWithLogitsLoss()
    loss_summ = []
    with torch.no_grad():
        for i, (img, gt) in enumerate(val_loader):
            if torch.cuda.is_available():
                img, gt = img.cuda(), gt.cuda()
            img, gt = Variable(img), Variable(gt)

            output = student(img)
            output = output.clamp(min=0, max=1)
            loss = dice_loss(output, gt)

            loss_summ.append(loss.item())

    mean_loss = np.mean(loss_summ)
    print('- Eval metrics:\n\tAverage Dice loss:{}'.format(mean_loss))
    return mean_loss
Example no. 8
def train_student(student, teacher_outputs, optimizer, train_loader):
    print('-------Train student-------')
    #called once for each epoch
    student.train().cuda()

    summ = []
    for i, (img, gt) in enumerate(train_loader):
        teacher_output = teacher_outputs[i]
        if torch.cuda.is_available():
            img, gt = img.cuda(), gt.cuda()
            teacher_output = teacher_output.cuda()

        img, gt = Variable(img), Variable(gt)
        teacher_output = Variable(teacher_output)

        output = student(img)

        #TODO: loss is wrong
        loss = loss_fn_kd(output, teacher_output, gt)

        # clear previous gradients, compute gradients of all variables wrt loss
        optimizer.zero_grad()
        loss.backward()

        # performs updates using calculated gradients
        optimizer.step()
        if i % summary_steps == 0:
            #do i need to move it to CPU?

            metric = dice_loss(output, gt)
            summary = {'metric': metric.item(), 'loss': loss.item()}
            summ.append(summary)

    #print('Average loss over this epoch: ' + np.mean(loss_avg))
    mean_dice_coeff = np.mean([x['metric'] for x in summ])
    mean_loss = np.mean([x['loss'] for x in summ])
    print('- Train metrics:\n' +
          '\tMetric:{}\n\tLoss:{}'.format(mean_dice_coeff, mean_loss))
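
train_student depends on a loss_fn_kd helper that is not shown (the TODO above notes the current loss is wrong). One common formulation for segmentation distillation combines a supervised term on the ground truth with a term that pulls the student towards the teacher's prediction; a sketch under that assumption, with a hypothetical alpha weight:

import torch.nn.functional as F

def loss_fn_kd(student_output, teacher_output, gt, alpha=0.5):
    # Assumed distillation loss: supervised Dice on the ground truth plus an MSE
    # term matching the student's prediction to the (detached) teacher output.
    supervised = dice_loss(student_output, gt)
    distill = F.mse_loss(student_output, teacher_output.detach())
    return alpha * supervised + (1.0 - alpha) * distill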
Example no. 9
def compute_class_weights(Class):
    # Assumed reconstruction: count per-class pixel frequencies; only the
    # median-frequency balancing return below comes from the original snippet.
    freq = np.bincount(np.asarray(Class).ravel())
    return np.median(freq) / freq


weights = compute_class_weights(Class)

# Define training functions

print "Defining and compiling training functions"

prediction = lasagne.layers.get_output(simple_net_output[0])

# deb_pred = theano.function([input_var], prediction)

# Loss function
#loss = weighted_crossentropy(prediction, target_var, weight_vector)
loss = dice_loss(prediction, target_var)
loss = loss.mean()

# Add regularization
if weight_decay > 0:
    weightsl2 = regularize_network_params(simple_net_output,
                                          lasagne.regularization.l2)
    loss += weight_decay * weightsl2

# Add penalty to enforce the same number of transitions:
if penalty_transitions > 0:
    true_prediction = T.reshape(target_var, (-1, 200))
    prediction_reshape = T.reshape(prediction, (-1, 200, 2))
    penalty_loss = abs(prediction_reshape[:, :, 1] -
                       true_prediction).sum(axis=1)
    loss += penalty_transitions * penalty_loss.mean()
Example no. 10
def train_model(model,
                dataloaders,
                policy_learner,
                optimizer,
                scheduler,
                num_epochs,
                device,
                writer,
                n_images=None):
    loader = {'val': dataloaders['val']}

    # best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10
    if n_images is None:
        n_images = {'train': 0, 'val': 0}

    for epoch in range(num_epochs):
        loader['train'] = policy_learner()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            # print('+++++++++ len loader', len(loader[phase]))
            if phase == 'train':
                if scheduler:
                    scheduler.step()
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            metrics = defaultdict(float)
            epoch_samples = 0

            for enum_id, (idxs, inputs,
                          labels) in tqdm(enumerate(loader[phase]),
                                          total=len(loader[phase])):
                inputs = inputs.to(device)
                labels = labels.to(device)
                # if phase == 'train' and enum_id < 3:
                #     for idx in idxs:
                #         torch.save(torch.tensor(1),
                #                    f'tmp/trash/{policy_learner.__class__.__name__}_{epoch}_{enum_id}__{idx}'
                #                    )

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    # loss, loss_sum, loss_bce, loss_dice = calc_loss(outputs, labels, 0)
                    loss = dice_loss(outputs, labels)
                    acc_f1 = calc_f1(outputs, labels)
                    # acc_iou = calc_IOU(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        plot_grad_flow(epoch, enum_id,
                                       model.named_parameters())
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)
                n_images[phase] += inputs.size(0)

                writer.add_scalar(f'{phase}/loss',
                                  loss.data.cpu().numpy(), n_images[phase])
                # writer.add_scalar(f'{phase}/bce', loss_bce, n_images[phase])
                # writer.add_scalar(f'{phase}/dice', loss_dice, n_images[phase])

                metrics['loss'] += loss * inputs.size(0)
                metrics['f1'] += acc_f1 * inputs.size(0)
                # metrics['iou'] += acc_iou * inputs.size(0)

            print_metrics(writer, metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples
            writer.add_scalar(f'{phase}/epoch_loss', epoch_loss, epoch)
            epoch_f1 = metrics['f1'] / epoch_samples
            writer.add_scalar(f'{phase}/epoch_F1', epoch_f1, epoch)
            # epoch_iou = metrics['iou'] / epoch_samples
            # writer.add_scalar(f'{phase}/epoch_IOU', epoch_iou, epoch)

            # # deep copy the model
            # if phase == 'val' and epoch_loss < best_loss:
            #     print("saving best model")
            #     best_loss = epoch_loss
            #     best_model_wts = copy.deepcopy(model.state_dict())

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    # model.load_state_dict(best_model_wts)
    return model, n_images
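
train_model delegates per-epoch reporting to a print_metrics helper that is not shown. A plausible sketch, assuming it averages each accumulated metric over the samples seen in the epoch and logs the result to stdout and to the TensorBoard writer (the tag names are assumptions):

def print_metrics(writer, metrics, epoch_samples, phase):
    # Assumed helper: report per-sample averages of the metrics accumulated above.
    parts = []
    for name, value in metrics.items():
        avg = float(value) / epoch_samples
        writer.add_scalar(f'{phase}/avg_{name}', avg)
        parts.append(f'{name}: {avg:.4f}')
    print(f'{phase}: ' + ', '.join(parts))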
Example no. 11
def general_loss(student_output, gt):
    # use torch.nn.CrossEntropyLoss()
    loss = dice_loss(student_output, gt)
    return loss
Example no. 12
def train(dataset, segm_net, learning_rate=0.005, lr_anneal=1.0,
          weight_decay=1e-4, num_epochs=500, max_patience=100,
          optimizer='rmsprop', training_loss=['squared_error'],
          batch_size=[10, 1, 1], ae_h=False,
          dae_dict_updates={}, data_augmentation={},
          savepath=None, loadpath=None, resume=False, train_from_0_255=False,
          lmb=1, full_im_ft=False):

    #
    # Update DAE parameters
    #
    dae_dict = {'kind': 'fcn8',
                'dropout': 0.0,
                'skip': True,
                'unpool_type': 'standard',
                'n_filters': 64,
                'conv_before_pool': 1,
                'additional_pool': 0,
                'concat_h': ['input'],
                'noise': 0.0,
                'from_gt': True,
                'temperature': 1.0,
                'path_weights': '',
                'layer': 'probs_dimshuffle',
                'exp_name': '',
                'bn': 0}

    dae_dict.update(dae_dict_updates)

    #
    # Prepare load/save directories
    #
    exp_name = build_experiment_name(segm_net,
                                     training_loss=training_loss,
                                     data_aug=bool(data_augmentation),
                                     learning_rate=learning_rate,
                                     lr_anneal=lr_anneal,
                                     weight_decay=weight_decay,
                                     optimizer=optimizer, ae_h=ae_h,
                                     **dae_dict)
    if savepath is None:
        raise ValueError('A saving directory must be specified')

    loadpath_init = os.path.join(loadpath, dataset, exp_name)
    exp_name += '_ft' if full_im_ft else ''
    loadpath = os.path.join(loadpath, dataset, exp_name)
    savepath = os.path.join(savepath, dataset, exp_name)
    if not os.path.exists(savepath):
        os.makedirs(savepath)
    else:
        print('\033[93m The following folder already exists {}. '
              'It will be overwritten in a few seconds...\033[0m'.format(
                  savepath))

    print('Saving directory : ' + savepath)
    with open(os.path.join(savepath, "config.txt"), "w") as f:
        for key, value in locals().items():
            f.write('{} = {}\n'.format(key, value))

    #
    # Define symbolic variables
    #
    input_x_var = T.tensor4('input_x_var')  # tensor for input image batch
    input_mask_var = T.tensor4('input_mask_var')  # tensor for segmentation batch (input dae)
    input_concat_h_vars = [T.tensor4()] * len(dae_dict['concat_h'])  # tensor for hidden repr batch (input dae)
    target_var = T.tensor4('target_var')  # tensor for target batch
    # learning_rate = learning_rate*0.1 if full_im_ft else learning_rate
    # learning_rate = 0.01
    print learning_rate
    lr = theano.shared(np.float32(learning_rate), 'learning_rate')

    #
    # Build dataset iterator
    #
    train_iter, val_iter, _ = load_data(dataset,
                                        data_augmentation,
                                        one_hot=True,
                                        batch_size=batch_size,
                                        return_0_255=train_from_0_255,
                                        )

    n_batches_train = train_iter.nbatches
    n_batches_val = val_iter.nbatches
    n_classes = train_iter.non_void_nclasses
    void_labels = train_iter.void_labels
    nb_in_channels = train_iter.data_shape[0]
    void = n_classes if any(void_labels) else n_classes+1

    #
    # Build networks
    #

    # Check that model and dataset get along
    print 'Checking options'
    assert (segm_net == 'fcn8' and dataset == 'camvid') or \
        (segm_net == 'densenet' and dataset == 'camvid')
    assert (data_augmentation['crop_size'] == None and full_im_ft) or not full_im_ft

    # Build segmentation network
    print 'Building segmentation network'
    if segm_net == 'fcn8':
        layer_out = copy.copy(dae_dict['concat_h'])
        layer_out += [copy.copy(dae_dict['layer'])] if not dae_dict['from_gt'] else []
        fcn = buildFCN8(nb_in_channels, input_x_var, n_classes=n_classes,
                        void_labels=void_labels,
                        path_weights=WEIGHTS_PATH+dataset+'/fcn8_model.npz',
                        load_weights=True,
                        layer=layer_out)
        padding = 100
    elif segm_net == 'densenet':
        fcn = build_fcdensenet(input_x_var, nb_in_channels=nb_in_channels,
                                n_classes=n_classes,
                               layer=dae_dict['concat_h'],
                               from_gt=dae_dict['from_gt'])
        padding = 0
    elif segm_net == 'fcn_fcresnet':
        raise NotImplementedError
    else:
        raise ValueError

    # Build DAE network
    print 'Building DAE network'

    if ae_h and dae_dict['kind'] != 'standard':
        raise ValueError('Plug&Play not implemented for ' + dae_dict['kind'])
    if ae_h and 'pool' not in dae_dict['concat_h'][-1]:
        raise ValueError('Plug&Play version needs concat_h to be different than input')
    ae_h = ae_h and 'pool' in dae_dict['concat_h'][-1]

    if dae_dict['kind'] == 'standard':
        nb_features_to_concat=fcn[0].output_shape[1]
        dae = buildDAE(input_concat_h_vars, input_mask_var, n_classes,
                       nb_features_to_concat=nb_features_to_concat, padding=padding,
                       trainable=True,
                       void_labels=void_labels, load_weights=resume or full_im_ft,
                       path_weights=loadpath_init, model_name='dae_model_best.npz',
                       out_nonlin=softmax, concat_h=dae_dict['concat_h'],
                       noise=dae_dict['noise'], n_filters=dae_dict['n_filters'],
                       conv_before_pool=dae_dict['conv_before_pool'],
                       additional_pool=dae_dict['additional_pool'],
                       dropout=dae_dict['dropout'], skip=dae_dict['skip'],
                       unpool_type=dae_dict['unpool_type'],
                       bn=dae_dict['bn'], ae_h=ae_h)
    elif dae_dict['kind'] == 'fcn8':
        dae = buildFCN8_DAE(input_concat_h_vars, input_mask_var, n_classes,
                            nb_in_channels=n_classes, trainable=True,
                            load_weights=resume, pretrained=True, pascal=True,
                            concat_h=dae_dict['concat_h'], noise=dae_dict['noise'],
                            dropout=dae_dict['dropout'],
                            path_weights=os.path.join('/'.join(loadpath_init.split('/')[:-1]),
                            dae_dict['path_weights']),
                            model_name='dae_model_best.npz')
    elif dae_dict['kind'] == 'contextmod':
        dae = buildDAE_contextmod(input_concat_h_vars, input_mask_var, n_classes,
                                  path_weights=loadpath_init,
                                  model_name='dae_model.npz',
                                  trainable=True, load_weights=resume,
                                  out_nonlin=softmax, noise=dae_dict['noise'],
                                  concat_h=dae_dict['concat_h'])
    else:
        raise ValueError('Unknown dae kind')

    #
    # Define and compile theano functions
    #

    # training functions
    print "Defining and compiling training functions"

    # fcn prediction
    fcn_prediction = lasagne.layers.get_output(fcn, deterministic=True, batch_norm_use_averages=False)

    # select prediction layers (pooling and upsampling layers)
    dae_all_lays = lasagne.layers.get_all_layers(dae)
    if dae_dict['kind'] != 'contextmod':
        dae_lays = [l for l in dae_all_lays
                    if isinstance(l, Pool2DLayer) or
                    isinstance(l, CroppingLayer) or
                    isinstance(l, ElemwiseSumLayer) or
                    l == dae_all_lays[-1]]
        # dae_lays = dae_lays[::2]
    else:
        dae_lays = [l for l in dae_all_lays if isinstance(l, DilatedConv2DLayer) or l == dae_all_lays[-1]]

    if ae_h:
        h_ae_idx = [i for i, el in enumerate(dae_lays) if el.name == 'h_to_recon'][0]
        h_hat_idx = [i for i, el in enumerate(dae_lays) if el.name == 'h_hat'][0]

    # predictions
    dae_prediction_all = lasagne.layers.get_output(dae_lays,
                                                   batch_norm_use_averages=False)
    dae_prediction = dae_prediction_all[-1]
    dae_prediction_h = dae_prediction_all[:-1]

    test_dae_prediction_all = lasagne.layers.get_output(dae_lays,
                                                        deterministic=True,
                                                        batch_norm_use_averages=False)
    test_dae_prediction = test_dae_prediction_all[-1]
    test_dae_prediction_h = test_dae_prediction_all[:-1]

    # fetch h and h_hat if needed
    if ae_h:
        h = dae_prediction_all[h_ae_idx]
        h_hat = dae_prediction_all[h_hat_idx]
        h_test = test_dae_prediction_all[h_ae_idx]
        h_hat_test = test_dae_prediction_all[h_hat_idx]

    # loss
    loss = 0
    test_loss = 0

    # Convert DAE prediction to 2D
    dae_prediction_2D = dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = dae_prediction_2D.shape
    dae_prediction_2D = dae_prediction_2D.reshape((T.prod(sh[:3]), sh[3]))

    test_dae_prediction_2D = test_dae_prediction.dimshuffle((0, 2, 3, 1))
    sh = test_dae_prediction_2D.shape
    test_dae_prediction_2D = test_dae_prediction_2D.reshape((T.prod(sh[:3]),
                                                            sh[3]))
    # Convert target to 2D
    target_var_2D = target_var.dimshuffle((0, 2, 3, 1))
    sh = target_var_2D.shape
    target_var_2D = target_var_2D.reshape((T.prod(sh[:3]), sh[3]))

    if 'crossentropy' in training_loss:
        # Compute loss
        loss += crossentropy(dae_prediction_2D, target_var_2D, void_labels,
                             one_hot=True)
        test_loss += crossentropy(test_dae_prediction_2D, target_var_2D,
                                  void_labels, one_hot=True)
    if 'dice' in training_loss:
        loss += dice_loss(dae_prediction, target_var, void_labels)
        test_loss += dice_loss(test_dae_prediction, target_var, void_labels)

    test_mse_loss = squared_error(test_dae_prediction, target_var, void)
    if 'squared_error' in training_loss:
        mse_loss = squared_error(dae_prediction, target_var, void)
        loss += lmb*mse_loss
        test_loss += lmb*test_mse_loss

    # Add intermediate losses
    if 'squared_error_h' in training_loss:
        # extract input layers and create dictionary
        dae_input_lays = [l for l in dae_all_lays if isinstance(l, InputLayer)]
        inputs = {dae_input_lays[0]: target_var[:, :void, :, :], dae_input_lays[-1]:target_var[:, :void, :, :]}
        for idx, val in enumerate(input_concat_h_vars):
            inputs[dae_input_lays[idx+1]] = val

        test_dae_prediction_all_gt = lasagne.layers.get_output(dae_lays,
                                                               inputs=inputs,
                                                               deterministic=True,
                                                               batch_norm_use_averages=False)
        test_dae_prediction_h_gt = test_dae_prediction_all_gt[:-1]

        loss += squared_error_h(dae_prediction_h, test_dae_prediction_h_gt)
        test_loss += squared_error_h(test_dae_prediction_h, test_dae_prediction_h_gt)

    # compute jaccard
    jacc = jaccard(dae_prediction_2D, target_var_2D, n_classes, one_hot=True)
    test_jacc = jaccard(test_dae_prediction_2D, target_var_2D, n_classes, one_hot=True)

    # if reconstructing h add the corresponding loss terms
    if ae_h:
        loss += squared_error_L(h, h_hat).mean()
        test_loss += squared_error_L(h_test, h_hat_test).mean()


    # network parameters
    params = lasagne.layers.get_all_params(dae, trainable=True)

    # optimizer
    if optimizer == 'rmsprop':
        updates = lasagne.updates.rmsprop(loss, params, learning_rate=lr)
    elif optimizer == 'adam':
        updates = lasagne.updates.adam(loss, params, learning_rate=lr)
    else:
        raise ValueError('Unknown optimizer')

    # functions
    train_fn = theano.function(input_concat_h_vars + [input_mask_var, target_var],
                               loss, updates=updates)

    fcn_fn = theano.function([input_x_var], fcn_prediction)
    val_fn = theano.function(input_concat_h_vars + [input_mask_var, target_var], [test_loss, test_jacc, test_mse_loss])

    err_train = []
    err_valid = []
    jacc_val_arr = []
    mse_val_arr = []
    patience = 0

    #
    # Train
    #
    # Training main loop
    print "Start training"
    for epoch in range(num_epochs):
        # Single epoch training and validation
        start_time = time.time()

        cost_train_tot = 0
        # Train
        for i in range(n_batches_train):
            # Get minibatch
            X_train_batch, L_train_batch = train_iter.next()
            L_train_batch = L_train_batch.astype(_FLOATX)

            #####uncomment if you want to control the feasibility of pooling####
            # max_n_possible_pool = np.floor(np.log2(np.array(X_train_batch.shape[2:]).min()))
            # # check if we don't ask for more poolings than possible
            # assert n_pool+additional_pool < max_n_possible_pool
            ####################################################################

            # h prediction
            H_pred_batch = fcn_fn(X_train_batch)

            if dae_dict['from_gt']:
                Y_pred_batch = L_train_batch[:, :void, :, :]
            else:
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Training step
            cost_train = train_fn(*(H_pred_batch + [Y_pred_batch, L_train_batch]))
            cost_train_tot += cost_train

        err_train += [cost_train_tot / n_batches_train]

        # Validation
        cost_val_tot = 0
        jacc_val_tot = 0
        mse_val_tot = 0
        for i in range(n_batches_val):
            # Get minibatch
            X_val_batch, L_val_batch = val_iter.next()
            L_val_batch = L_val_batch.astype(_FLOATX)

            # h prediction
            H_pred_batch = fcn_fn(X_val_batch)

            if dae_dict['from_gt']:
                Y_pred_batch = L_val_batch[:, :void, :, :]
            else:
                Y_pred_batch = H_pred_batch[-1]
                H_pred_batch = H_pred_batch[:-1]

            # Validation step
            cost_val, jacc_val, mse_val = val_fn(*(H_pred_batch + [Y_pred_batch, L_val_batch]))
            cost_val_tot += cost_val
            jacc_val_tot += jacc_val
            mse_val_tot += mse_val

        err_valid += [cost_val_tot / n_batches_val]
        jacc_val_arr += [np.mean(jacc_val_tot[0, :] / jacc_val_tot[1, :])]
        mse_val_arr += [mse_val_tot /  n_batches_val]

        out_str = "EPOCH %i: Avg epoch training cost train %f, cost val %f," + \
                  " jacc val %f, mse val % f took %f s"
        out_str = out_str % (epoch, err_train[epoch],
                             err_valid[epoch],
                             jacc_val_arr[epoch],
                             mse_val_arr[epoch],
                             time.time() - start_time)
        print out_str

        with open(os.path.join(savepath, "output.log"), "a") as f:
            f.write(out_str + "\n")

        # update learning rate
        lr.set_value(float(lr.get_value() * lr_anneal))

        # Early stopping and saving stuff
        if epoch == 0:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
        elif epoch > 0  and err_valid[epoch] < best_err_val:
            best_err_val = err_valid[epoch]
            best_jacc_val = jacc_val_arr[epoch]
            best_mse_val = mse_val_arr[epoch]
            patience = 0
            np.savez(os.path.join(savepath, 'dae_model_best.npz'),
                     *lasagne.layers.get_all_param_values(dae))
            np.savez(os.path.join(savepath, 'dae_errors_best.npz'),
                     err_train, err_valid, jacc_val_arr, mse_val_arr)
        else:
            patience += 1
            np.savez(os.path.join(savepath, 'dae_model_last.npz'),
                     *lasagne.layers.get_all_param_values(dae))
            np.savez(os.path.join(savepath, 'dae_errors_last.npz'),
                     err_train, err_valid, jacc_val_arr, mse_val_arr)

        # Finish training if patience has expired or the max number of epochs
        # has been reached
        if patience == max_patience or epoch == num_epochs - 1:
            # Copy files to loadpath
            if savepath != loadpath:
                print('Copying model and other training files to {}'.format(
                    loadpath))
                copy_tree(savepath, loadpath)
            # End
            print(' Training Done !')
            return
Example no. 13
    def validate(images, labels):
        output = model(images, training=False)
        loss = dice_loss(labels, output)
        metric = metrics(labels, output)

        return loss, metric
Example no. 14
def train(
    epochs: int,
    models_dir: Path,
    x_cities: List[CityData],
    y_city: List[CityData],
    mask_dir: Path,
):
    model = UNet11().cuda()
    optimizer = Adam(model.parameters(), lr=3e-4)
    scheduler = ReduceLROnPlateau(optimizer, patience=4, factor=0.25)
    min_loss = sys.maxsize
    criterion = nn.BCEWithLogitsLoss()
    train_data = DataLoader(TrainDataset(x_cities, mask_dir),
                            batch_size=4,
                            num_workers=4,
                            shuffle=True)
    test_data = DataLoader(TestDataset(y_city, mask_dir),
                           batch_size=6,
                           num_workers=4)

    for epoch in range(epochs):
        print(f'Epoch {epoch}, lr {optimizer.param_groups[0]["lr"]}')
        print(f"    Training")

        losses = []
        ious = []
        jaccs = []

        batch_iterator = enumerate(train_data)

        model = model.train().cuda()
        for i, (x, y) in tqdm(batch_iterator):
            optimizer.zero_grad()
            x = x.cuda()
            y = y.cuda()

            y_real = y.view(-1).float()
            y_pred = model(x)
            y_pred_probs = torch.sigmoid(y_pred).view(-1)
            loss = 0.75 * criterion(y_pred.view(
                -1), y_real) + 0.25 * dice_loss(y_pred_probs, y_real)

            iou_ = iou(y_pred_probs.float(), y_real.byte())
            jacc_ = jaccard(y_pred_probs.float(), y_real)
            ious.append(iou_.item())
            losses.append(loss.item())
            jaccs.append(jacc_.item())

            loss.backward()
            optimizer.step()

        print(
            f"Loss: {np.mean(losses)}, IOU: {np.mean(ious)}, jacc: {np.mean(jaccs)}"
        )

        model = model.eval().cuda()
        losses = []
        ious = []
        jaccs = []

        with torch.no_grad():
            batch_iterator = enumerate(test_data)
            for i, (x, y) in tqdm(batch_iterator):
                x = x.cuda()
                y = y.cuda()
                y_real = y.view(-1).float()
                y_pred = model(x)
                y_pred_probs = torch.sigmoid(y_pred).view(-1)
                loss = 0.75 * criterion(y_pred.view(
                    -1), y_real) + 0.25 * dice_loss(y_pred_probs, y_real)

                iou_ = iou(y_pred_probs.float(), y_real.byte())
                jacc_ = jaccard(y_pred_probs.float(), y_real)
                ious.append(iou_.item())
                losses.append(loss.item())
                jaccs.append(jacc_.item())
            test_loss = np.mean(losses)
            print(
                f"Loss: {np.mean(losses)}, IOU: {np.mean(ious)}, jacc: {np.mean(jaccs)}"
            )

        scheduler.step(test_loss)
        if test_loss < min_loss:
            min_loss = test_loss
            save_model(model, epoch, models_dir / y_city[0].name)
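
Finally, the checkpointing goes through a save_model helper that is not part of the listing. A minimal sketch, assuming it writes the weights and the epoch number to the given path:

import torch
from pathlib import Path

def save_model(model, epoch: int, path: Path):
    # Assumed helper: persist the current weights together with the epoch index.
    path.parent.mkdir(parents=True, exist_ok=True)
    torch.save({'epoch': epoch, 'state_dict': model.state_dict()}, path)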