def trainModelPriv(model):
    if useGPU:
        model.cuda(gpu0)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)
    optimizer.zero_grad()
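    #gradients are accumulated over iter_size mini-batches before each optimizer.step()
    #(see the iter % iter_size block at the bottom of the loop), so they are zeroed once here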
    print(model)
    curr_val1 = 0
    curr_val2 = 0
    best_val2 = 0
    val_change = False
    loss_arr1 = np.zeros([iter_size])
    loss_arr2 = np.zeros([iter_size])
    loss_arr_i = 0

    stage = 0
    print('---------------')
    print('STAGE ' + str(stage))
    print('---------------')

    for iter in range(iter_low, iter_high):
        if iter > max_iter_stage0 and stage != 1:
            print('---------------')
            print('STAGE 1')
            print('---------------')
            stage = 1

        if train_method == 0:
            img_b, label_b, gif_b = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions, 
                                                    main_folder_path = '../Data/MS2017b/', with_priv = True)
        elif train_method == 1 or train_method == 2:
            if stage == 0:
                batch_size = 5
                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions,
                                                            center_pixel = to_center_pixel, 
                                                            main_folder_path = '../Data/MS2017b/', 
                                                            postfix=postfix, with_priv = True)
            else:
                batch_size = 1
                img_b, label_b, gif_b = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions, 
                                                    center_pixel = to_center_pixel, 
                                                    main_folder_path = '../Data/MS2017b/', 
                                                    postfix=postfix, with_priv = True)
        else:
            print('Invalid train_method value: ' + str(train_method))
            sys.exit()

        img_b, label_b, gif_b = AUG.augmentPatchLossy([img_b, label_b, gif_b])
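        #the lossy augmentation is applied to image, label and GIF volumes as one group,
        #presumably so that random spatial transforms stay aligned across all three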

        #img_b is of shape      (batch_num) x 1 x dim1 x dim2 x dim3
        #label_b is of shape    (batch_num) x 1 x dim1 x dim2 x dim3

        label_b = label_b.astype(np.int64)

        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
        #               to  ((batch_num*dim1*dim2*dim3) x num_labels) (one-hot)
        temp = label_b.reshape([-1])
        label_b = np.zeros([temp.size, num_labels])
        label_b[np.arange(temp.size),temp] = 1
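        #fancy indexing: row i gets a 1 in column temp[i], i.e. one one-hot row per voxel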
        label_b = torch.from_numpy(label_b).float()

        imgs = torch.from_numpy(img_b).float()

        if useGPU:
            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
        else:
            imgs, label_b = Variable(imgs), Variable(label_b)

        gif_b = setupGIFVar(gif_b)
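        #setupGIFVar presumably wraps the per-voxel GIF parcellation labels as the target
        #Variable for the auxiliary head (moved to the GPU when useGPU is set)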

        #---------------------------------------------
        #out1 size is     (batch_num, num_labels2, dim1, dim2, dim3)
        #out2 size is     (batch_num, num_labels, dim1, dim2, dim3)
        #---------------------------------------------
        #out1 is the auxiliary (privileged GIF) head used for loss1 below
        out1, out2 = model(imgs)
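        #permute moves the channel axis last and contiguous() re-lays memory so that
        #view() can flatten each output to (voxels, classes) for the per-voxel losses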

        out1 = out1.permute(0,2,3,4,1).contiguous()
        out1 = out1.view(-1, num_labels2)

        out2 = out2.permute(0,2,3,4,1).contiguous()
        out2 = out2.view(-1, num_labels)
        #---------------------------------------------
        #both outputs are now (batch_num * dim1 * dim2 * dim3, classes)
        #---------------------------------------------
        m2 = nn.Softmax(dim=1)
        loss2 = lossF.simple_dice_loss3D(m2(out2), label_b)
        m1 = nn.LogSoftmax(dim=1)
        loss1 = F.nll_loss(m1(out1), gif_b)

        loss1 /= iter_size
        loss2 /= iter_size

        torch.autograd.backward([loss1, loss2])
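        #a single backward pass over both losses; gradients from the main and
        #auxiliary heads are summed into the shared parameters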

        loss_val1 = float(loss1.data.cpu().numpy())
        loss_arr1[loss_arr_i] = loss_val1

        loss_val2 = float(loss2.data.cpu().numpy())
        loss_arr2[loss_arr_i] = loss_val2

        loss_arr_i = (loss_arr_i + 1) % iter_size
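        #loss_arr1/loss_arr2 are circular buffers over the last iter_size steps; since
        #each stored loss was divided by iter_size, their sum (logged below) is the
        #average unscaled loss over one accumulation window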

        #print running progress every iteration
        if val_change:
            print "iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}     \r".format(iter-1, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1),
            sys.stdout.flush()
            print ""
            val_change = False
        print "iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}     \r".format(iter, max_iter, loss_val2*iter_size, loss_val1*iter_size, curr_val2, curr_val1),
        sys.stdout.flush()
        if iter % 2000 == 0:
            val_change = True
            curr_val1, curr_val2 = EFP.evalModelX(model, num_labels, num_labels2, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5, priv_eval = True)
            if curr_val2 > best_val2:
                best_val2 = curr_val2
                print('\nSaving better model...')
                torch.save(model.state_dict(), model_file_path)
            logfile.write("iter = {:6d}/{:6d}       Loss_main: {:1.6f}      Loss_secondary: {:1.6f}       Val Score main: {:1.6f}      Val Score secondary: {:1.6f}  \n".format(iter, max_iter, np.sum(loss_arr2), np.sum(loss_arr1), curr_val2, curr_val1))
            logfile.flush()
        if iter % iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()

        del out1, out2, loss1, loss2
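
#Illustrative sketch (not called anywhere in this file): the label flattening that both
#trainers perform inline, factored into a helper. The name is hypothetical; np and
#torch are assumed to be imported at module level, as they are used elsewhere here.
def _flatten_one_hot(label_b, num_labels):
    #(batch, 1, d1, d2, d3) integer label volume -> (batch*d1*d2*d3, num_labels) one-hot float tensor
    flat = label_b.astype(np.int64).reshape([-1])
    one_hot = np.zeros([flat.size, num_labels])
    one_hot[np.arange(flat.size), flat] = 1
    return torch.from_numpy(one_hot).float()
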
def trainModel(model):
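    #single-head variant of trainModelPriv: one segmentation output trained with the
    #dice loss only, no auxiliary GIF head; same deferred optimizer.step() scheme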
    if useGPU:
        model.cuda(gpu0)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = base_lr)

    optimizer.zero_grad()
    print(model)
    curr_val = 0
    best_val = 0
    val_change = False
    loss_arr = np.zeros([iter_size])
    loss_arr_i = 0
    stage = 0
    print('---------------')
    print('STAGE ' + str(stage))
    print('---------------')

    for iter in range(iter_low, iter_high):
        if iter > max_iter_stage0 and stage != 1:
            print('---------------')
            print('STAGE 1')
            print('---------------')
            stage = 1

        if train_method == 0:
            img_b, label_b, _ = PP.extractImgBatch(batch_size, img_list, img_dims, onlyLesions, 
                                                    main_folder_path = '../Data/MS2017b/')
        elif train_method == 1 or train_method == 2:
            if stage == 0:
                batch_size = 1
                img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size_stage0, img_list, onlyLesions,
                                                        center_pixel = to_center_pixel,
                                                        main_folder_path = '../Data/MS2017b/', postfix=postfix)
            else:
                batch_size = 1
                img_b, label_b, _ = PP.extractPatchBatch(batch_size, patch_size, img_list, onlyLesions,
                                                        center_pixel = to_center_pixel,
                                                        main_folder_path = '../Data/MS2017b/', postfix=postfix)
        else:
            print('Invalid train_method value: ' + str(train_method))
            sys.exit()

        if stage == 0:
            img_b, label_b = AUG.augmentPatchLossLess([img_b, label_b])
        img_b, label_b = AUG.augmentPatchLossy([img_b, label_b])
        #img_b is of shape      (batch_num) x 1 x dim1 x dim2 x dim3
        #label_b is of shape    (batch_num) x 1 x dim1 x dim2 x dim3
        #batch_num is kept at 1 since larger batches are too memory-intensive

        label_b = label_b.astype(np.int64)
        #convert label from (batch_num x 1 x dim1 x dim2 x dim3)
        #               to  ((batch_num*dim1*dim2*dim3) x num_labels) (one-hot)
        temp = label_b.reshape([-1])
        label_b = np.zeros([temp.size, num_labels])
        label_b[np.arange(temp.size),temp] = 1
        label_b = torch.from_numpy(label_b).float()

        imgs = torch.from_numpy(img_b).float()

        if useGPU:
            imgs, label_b = Variable(imgs).cuda(gpu0), Variable(label_b).cuda(gpu0)
        else:
            imgs, label_b = Variable(imgs), Variable(label_b)

        #---------------------------------------------
        #out size is      (batch_num, num_labels, dim1, dim2, dim3)
        #---------------------------------------------
        out = model(imgs)
        out = out.permute(0,2,3,4,1).contiguous()
        out = out.view(-1, num_labels)
        #---------------------------------------------
        #out size is now  (batch_num * dim1 * dim2 * dim3, num_labels)
        #---------------------------------------------

        #loss function
        m = nn.Softmax(dim=1)
        loss = lossF.simple_dice_loss3D(m(out), label_b)

        loss /= iter_size
        loss.backward()

        loss_val = float(loss.data.cpu().numpy())
        loss_arr[loss_arr_i] = loss_val
        loss_arr_i = (loss_arr_i + 1) % iter_size

        #print running progress every iteration
        if val_change:
            print "iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     \r".format(iter-1, max_iter, loss_val*iter_size, curr_val),
            sys.stdout.flush()
            print ""
            val_change = False
        print "iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     \r".format(iter, max_iter, loss_val*iter_size, curr_val),
        sys.stdout.flush()
        if iter % 1000 == 0:
            val_change = True
            curr_val = EF.evalModelX(model, num_labels, postfix, main_folder_path, (train_method != 0), gpu0, useGPU, eval_metric = 'iou', patch_size = patch_size, extra_patch = 5)
            if curr_val > best_val:
                best_val = curr_val
                print('\nSaving better model...')
                torch.save(model.state_dict(), model_file_path)
            logfile.write("iter = {:6d}/{:6d}       Loss: {:1.6f}       Val Score: {:1.6f}     \n".format(iter, max_iter, np.sum(loss_arr), curr_val))
            logfile.flush()
        if iter % iter_size == 0:
            optimizer.step()
            optimizer.zero_grad()

        del out, loss
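
#Minimal standalone sketch of the gradient-accumulation pattern both trainers rely on;
#a hypothetical demo, not part of the training path. model, batches and loss_fn are
#placeholders for illustration only.
def _accumulation_demo(model, batches, loss_fn, iter_size=10, lr=1e-4):
    opt = optim.Adam(model.parameters(), lr=lr)
    opt.zero_grad()
    for i, (x, y) in enumerate(batches):
        loss = loss_fn(model(x), y) / iter_size   #scale so accumulated grads average out
        loss.backward()                           #gradients accumulate in .grad buffers
        if (i + 1) % iter_size == 0:
            opt.step()                            #apply the accumulated update
            opt.zero_grad()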