Example #1
def train(args):
    # Setup Dataloader
    wc_data_loader = get_loader('doc3dwc')
    data_path = args.data_path
    wc_t_loader = wc_data_loader(data_path,
                                 is_transform=True,
                                 img_size=(args.wc_img_rows, args.wc_img_cols),
                                 augmentations=args.augmentation)
    wc_v_loader = wc_data_loader(data_path,
                                 is_transform=True,
                                 split='val',
                                 img_size=(args.wc_img_rows, args.wc_img_cols))

    wc_n_classes = wc_t_loader.n_classes
    wc_trainloader = data.DataLoader(wc_t_loader,
                                     batch_size=args.batch_size,
                                     num_workers=8,
                                     shuffle=True)
    wc_valloader = data.DataLoader(wc_v_loader,
                                   batch_size=args.batch_size,
                                   num_workers=8)

    # Setup Model
    model_wc = get_model('unetnc', wc_n_classes, in_channels=3)
    model_wc = torch.nn.DataParallel(model_wc,
                                     device_ids=range(
                                         torch.cuda.device_count()))
    model_wc.cuda()

    # Setup Dataloader
    bm_data_loader = get_loader('doc3dbmnic')
    bm_t_loader = bm_data_loader(data_path,
                                 is_transform=True,
                                 img_size=(args.bm_img_rows, args.bm_img_cols))
    bm_v_loader = bm_data_loader(data_path,
                                 is_transform=True,
                                 split='val',
                                 img_size=(args.bm_img_rows, args.bm_img_cols))

    bm_n_classes = bm_t_loader.n_classes
    bm_trainloader = data.DataLoader(bm_t_loader,
                                     batch_size=args.batch_size,
                                     num_workers=8,
                                     shuffle=True)
    bm_valloader = data.DataLoader(bm_v_loader,
                                   batch_size=args.batch_size,
                                   num_workers=8)

    # Setup Model
    model_bm = get_model('dnetccnl', bm_n_classes, in_channels=3)
    model_bm = torch.nn.DataParallel(model_bm,
                                     device_ids=range(
                                         torch.cuda.device_count()))
    model_bm.cuda()

    if os.path.isfile(args.shape_net_loc):
        print("Loading model_wc from checkpoint '{}'".format(
            args.shape_net_loc))
        checkpoint = torch.load(args.shape_net_loc)
        model_wc.load_state_dict(checkpoint['model_state'])
        print("Loaded checkpoint '{}' (epoch {})".format(
            args.shape_net_loc, checkpoint['epoch']))
    else:
        print("No model_wc checkpoint found at '{}'".format(
            args.shape_net_loc))
        exit(1)
    if os.path.isfile(args.texture_mapping_net_loc):
        print("Loading model_bm from checkpoint '{}'".format(
            args.texture_mapping_net_loc))
        checkpoint = torch.load(args.texture_mapping_net_loc)
        model_bm.load_state_dict(checkpoint['model_state'])
        print("Loaded checkpoint '{}' (epoch {})".format(
            args.texture_mapping_net_loc, checkpoint['epoch']))
    else:
        print("No model_bm checkpoint found at '{}'".format(
            args.texture_mapping_net_loc))
        exit(1)

    # Activation
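    # Hardtanh(0, 1.0) clamps the regressed world coordinates to the
    # normalized [0, 1] range.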
    htan = nn.Hardtanh(0, 1.0)

    # Optimizer
    optimizer = torch.optim.Adam(list(model_wc.parameters()) +
                                 list(model_bm.parameters()),
                                 lr=args.l_rate,
                                 weight_decay=5e-4,
                                 amsgrad=True)

    # LR Scheduler
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor=0.5,
                                                       patience=5,
                                                       verbose=True)

    # Losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    gloss = grad_loss.Gradloss(window_size=5, padding=2)
    reconst_loss = recon_lossc.Unwarploss()

    epoch_start = 0

    # Log file:
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    experiment_name = 'joint train'
    log_file_name = os.path.join(args.logdir, experiment_name + '.txt')
    if os.path.isfile(log_file_name):
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w+')

    log_file.write('\n---------------  ' + experiment_name +
                   '  ---------------\n')
    log_file.close()

    # Setup tensorboard for visualization
    if args.tboard:
        # save logs in runs/<experiment_name>
        writer = SummaryWriter(comment=experiment_name)

    best_val_mse = 99999.0
    global_step = 0
    LClambda = 0.2
    bm_img_size = (128, 128)

    alpha = 0.5
    beta = 0.5
    for epoch in range(epoch_start, args.n_epoch):
        avg_loss = 0.0

        wc_avg_l1loss = 0.0
        wc_avg_gloss = 0.0
        wc_train_mse = 0.0

        bm_avgl1loss = 0.0
        bm_avgrloss = 0.0
        bm_avgssimloss = 0.0
        bm_train_mse = 0.0

        model_wc.train()
        model_bm.train()
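        # Ramp up the gradient-loss weight mid-training (as written, this
        # condition can only fire during the epoch-50 iteration).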
        if epoch == 50 and LClambda < 1.0:
            LClambda += 0.2
        for (i, (wc_images, wc_labels)), (_, (bm_images, bm_labels)) in zip(
                enumerate(wc_trainloader), enumerate(bm_trainloader)):
            wc_images = Variable(wc_images.cuda())
            wc_labels = Variable(wc_labels.cuda())

            optimizer.zero_grad()
            wc_outputs = model_wc(wc_images)
            pred_wc = htan(wc_outputs)
            g_loss = gloss(pred_wc, wc_labels)
            wc_l1loss = loss_fn(pred_wc, wc_labels)
            loss = alpha * (wc_l1loss + LClambda * g_loss)

            bm_images = Variable(bm_images.cuda())
            bm_labels = Variable(bm_labels.cuda())
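            # Chain the two networks: downsample the predicted world
            # coordinates to the backward-mapping network's input size.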
            bm_input = F.interpolate(pred_wc, bm_img_size)

            target = model_bm(bm_input)
            target_nhwc = target.transpose(1, 2).transpose(2, 3)
            bm_l1loss = loss_fn(target_nhwc, bm_labels)
            rloss, ssim, uworg, uwpred = reconst_loss(bm_images[:, :-1, :, :],
                                                      target_nhwc, bm_labels)
            loss += beta * ((10.0 * bm_l1loss) + (0.5 * rloss))

            avg_loss += float(loss)

            wc_avg_l1loss += float(wc_l1loss)
            wc_avg_gloss += float(g_loss)
            wc_train_mse += float(MSE(pred_wc, wc_labels).item())

            bm_avgl1loss += float(bm_l1loss)
            bm_avgrloss += float(rloss)
            bm_avgssimloss += float(ssim)

            bm_train_mse += float(MSE(target_nhwc, bm_labels).item())

            loss.backward()
            optimizer.step()
            global_step += 1

            if (i + 1) % 50 == 0:
                print("Epoch[%d/%d] Batch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(wc_trainloader),
                       avg_loss / 50.0))
                avg_loss = 0.0

            if args.tboard and (i + 1) % 20 == 0:
                show_wc_tnsboard(global_step, writer, wc_images, wc_labels,
                                 pred_wc, 8, 'Train Inputs', 'Train WCs',
                                 'Train Pred. WCs')
                writer.add_scalar('WC: L1 Loss/train', wc_avg_l1loss / (i + 1),
                                  global_step)
                writer.add_scalar('WC: Grad Loss/train',
                                  wc_avg_gloss / (i + 1), global_step)
                show_unwarp_tnsboard(bm_images, global_step, writer, uwpred,
                                     uworg, 8, 'Train GT unwarp',
                                     'Train Pred Unwarp')
                writer.add_scalar('BM: L1 Loss/train', bm_avgl1loss / (i + 1),
                                  global_step)
                writer.add_scalar('CB: Recon Loss/train',
                                  bm_avgrloss / (i + 1), global_step)
                writer.add_scalar('CB: SSIM Loss/train',
                                  bm_avgssimloss / (i + 1), global_step)

        wc_train_mse = wc_train_mse / len(wc_trainloader)
        wc_avg_l1loss = wc_avg_l1loss / len(wc_trainloader)
        wc_avg_gloss = wc_avg_gloss / len(wc_trainloader)
        print("wc Training L1:%4f" % (wc_avg_l1loss))
        print("wc Training MSE:'{}'".format(wc_train_mse))
        wc_train_losses = [wc_avg_l1loss, wc_train_mse, wc_avg_gloss]

        lrate = get_lr(optimizer)

        write_log_file(log_file_name, wc_train_losses, epoch + 1, lrate,
                       'Train', 'wc')

        bm_avgssimloss = bm_avgssimloss / len(bm_trainloader)
        bm_avgrloss = bm_avgrloss / len(bm_trainloader)
        bm_avgl1loss = bm_avgl1loss / len(bm_trainloader)
        bm_train_mse = bm_train_mse / len(bm_trainloader)
        print("bm Training L1:%4f" % (bm_avgl1loss))
        print("bm Training MSE:'{}'".format(bm_train_mse))
        bm_train_losses = [
            bm_avgl1loss, bm_train_mse, bm_avgrloss, bm_avgssimloss
        ]

        write_log_file(log_file_name, bm_train_losses, epoch + 1, lrate,
                       'Train', 'bm')

        model_wc.eval()
        model_bm.eval()

        val_mse = 0.0
        val_loss = 0.0

        wc_val_loss = 0.0
        wc_val_gloss = 0.0
        wc_val_mse = 0.0

        bm_val_l1loss = 0.0
        val_rloss = 0.0
        val_ssimloss = 0.0
        bm_val_mse = 0.0

        for (i_val, (wc_images_val, wc_labels_val)), \
                (_, (bm_images_val, bm_labels_val)) in tqdm(
                    zip(enumerate(wc_valloader), enumerate(bm_valloader))):
            with torch.no_grad():
                wc_images_val = Variable(wc_images_val.cuda())
                wc_labels_val = Variable(wc_labels_val.cuda())

                wc_outputs = model_wc(wc_images_val)
                pred_val = htan(wc_outputs)
                wc_batch_l1 = float(loss_fn(pred_val, wc_labels_val))
                wc_batch_mse = float(MSE(pred_val, wc_labels_val))
                wc_val_loss += wc_batch_l1
                wc_val_mse += wc_batch_mse
                wc_val_gloss += float(gloss(pred_val, wc_labels_val))

                bm_images_val = Variable(bm_images_val.cuda())
                bm_labels_val = Variable(bm_labels_val.cuda())
                bm_input = F.interpolate(pred_val, bm_img_size)
                target = model_bm(bm_input)
                target_nhwc = target.transpose(1, 2).transpose(2, 3)
                bm_batch_l1 = float(loss_fn(target_nhwc, bm_labels_val))
                bm_batch_mse = float(MSE(target_nhwc, bm_labels_val))
                bm_val_l1loss += bm_batch_l1
                bm_val_mse += bm_batch_mse
                rloss, ssim, uworg, uwpred = reconst_loss(
                    bm_images_val[:, :-1, :, :], target_nhwc, bm_labels_val)
                val_rloss += float(rloss)
                val_ssimloss += float(ssim)
                # Accumulate per-batch values rather than the running sums,
                # so the totals are not double-counted across iterations.
                val_loss += (alpha * wc_batch_l1 + beta * bm_batch_l1)
                val_mse += (wc_batch_mse + bm_batch_mse)
            if args.tboard:
                show_unwarp_tnsboard(bm_images_val, epoch + 1, writer, uwpred,
                                     uworg, 8, 'Val GT unwarp',
                                     'Val Pred Unwarp')

        if args.tboard:
            show_wc_tnsboard(epoch + 1, writer, wc_images_val, wc_labels_val,
                             pred_val, 8, 'Val Inputs', 'Val WCs',
                             'Val Pred. WCs')
            writer.add_scalar('WC: L1 Loss/val', wc_val_loss, epoch + 1)
            writer.add_scalar('WC: Grad Loss/val', wc_val_gloss, epoch + 1)

            writer.add_scalar('BM: L1 Loss/val', bm_val_l1loss, epoch + 1)
            writer.add_scalar('CB: Recon Loss/val', val_rloss, epoch + 1)
            writer.add_scalar('CB: SSIM Loss/val', val_ssimloss, epoch + 1)
            writer.add_scalar('total val loss', val_loss, epoch + 1)

        wc_val_loss = wc_val_loss / len(wc_valloader)
        wc_val_mse = wc_val_mse / len(wc_valloader)
        wc_val_gloss = wc_val_gloss / len(wc_valloader)
        print("wc val loss at epoch {}:: {}".format(epoch + 1, wc_val_loss))
        print("wc val MSE: {}".format(wc_val_mse))

        bm_val_l1loss = bm_val_l1loss / len(bm_valloader)
        bm_val_mse = bm_val_mse / len(bm_valloader)
        val_ssimloss = val_ssimloss / len(bm_valloader)
        val_rloss = val_rloss / len(bm_valloader)
        print("bm val loss at epoch {}:: {}".format(epoch + 1, bm_val_l1loss))
        print("bm val mse: {}".format(bm_val_mse))

        val_loss /= len(wc_valloader)
        val_mse /= len(wc_valloader)
        print("val loss at epoch {}:: {}".format(epoch + 1, val_loss))
        print("val mse: {}".format(val_mse))

        bm_val_losses = [bm_val_l1loss, bm_val_mse, val_rloss, val_ssimloss]
        wc_val_losses = [wc_val_loss, wc_val_mse, wc_val_gloss]
        total_val_losses = [val_loss, val_mse]
        write_log_file(log_file_name, wc_val_losses, epoch + 1, lrate, 'Val',
                       'wc')
        write_log_file(log_file_name, bm_val_losses, epoch + 1, lrate, 'Val',
                       'bm')
        write_log_file(log_file_name, total_val_losses, epoch + 1, lrate,
                       'Val', 'total')

        # reduce learning rate
        sched.step(val_mse)

        if val_mse < best_val_mse:
            best_val_mse = val_mse
            state_wc = {
                'epoch': epoch + 1,
                'model_state': model_wc.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state_wc,
                args.logdir + "{}_{}_{}_{}_{}_best_wc_model.pkl".format(
                    'unetnc', epoch + 1, wc_val_mse, wc_train_mse,
                    experiment_name))
            state_bm = {
                'epoch': epoch + 1,
                'model_state': model_bm.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state_bm,
                args.logdir + "{}_{}_{}_{}_{}_best_bm_model.pkl".format(
                    'dnetccnl', epoch + 1, bm_val_mse, bm_train_mse,
                    experiment_name))

        if (epoch + 1) % 10 == 0 and epoch > 70:
            state_wc = {
                'epoch': epoch + 1,
                'model_state': model_wc.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state_wc, args.logdir + "{}_{}_{}_{}_{}_wc_model.pkl".format(
                    'unetnc', epoch + 1, wc_val_mse, wc_train_mse,
                    experiment_name))
            state_bm = {
                'epoch': epoch + 1,
                'model_state': model_bm.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state_bm, args.logdir + "{}_{}_{}_{}_{}_bm_model.pkl".format(
                    'dnetccnl', epoch + 1, bm_val_mse, bm_train_mse,
                    experiment_name))
Example #2
def train(args):

    # Setup Dataloader
    data_loader = get_loader('doc3dbmnic')
    data_path = args.data_path
    print('Starting . . .')
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(args.img_rows, args.img_cols))
    v_loader = data_loader(data_path,
                           is_transform=True,
                           split='val',
                           img_size=(args.img_rows, args.img_cols))

    n_classes = t_loader.n_classes
    print('Loading training data . . .')
    trainloader = data.DataLoader(t_loader,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True)
    print('Loading validation data . . .')
    valloader = data.DataLoader(v_loader,
                                batch_size=args.batch_size,
                                num_workers=8)

    # Setup Model
    print('Loading model . . .')
    model = get_model(args.arch, n_classes, in_channels=3)
    model = torch.nn.DataParallel(model,
                                  device_ids=range(torch.cuda.device_count()))
    model.cuda()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.l_rate,
                                 weight_decay=5e-4,
                                 amsgrad=True)

    # LR Scheduler
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor=0.5,
                                                       patience=3,
                                                       verbose=True)

    # Losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    reconst_loss = recon_lossc.Unwarploss()

    epoch_start = 0
    if args.resume is not None:
        if os.path.isfile(args.resume):
            print("Loading model and optimizer from checkpoint '{}'".format(
                args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            print("Loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            epoch_start = checkpoint['epoch']
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    # Log file:
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    # network_activation(t=[-1,1])_dataset_lossparams_augmentations_trainstart
    experiment_name = 'dnetccnl_htan_swat3dmini1kbm_l1_noaug_scratch'
    log_file_name = os.path.join(args.logdir, experiment_name + '.txt')
    if os.path.isfile(log_file_name):
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w+')

    log_file.write('\n---------------  ' + experiment_name +
                   '  ---------------\n')
    log_file.close()

    # Setup tensorboard for visualization
    if args.tboard:
        # save logs in runs/<experiment_name>
        writer = SummaryWriter(comment=experiment_name)

    best_val_uwarpssim = 99999.0
    best_val_mse = 99999.0
    global_step = 0

    for epoch in range(epoch_start, args.n_epoch):
        avg_loss = 0.0
        avgl1loss = 0.0
        avgrloss = 0.0
        avgssimloss = 0.0
        train_mse = 0.0
        model.train()

        for i, (images, labels) in enumerate(trainloader):
            images = Variable(images.cuda())
            labels = Variable(labels.cuda())
            optimizer.zero_grad()
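            # The loader stacks several maps per sample; channels 3 onward are
            # the network input, while the leading channels feed the
            # reconstruction loss below.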
            target = model(images[:, 3:, :, :])
            target_nhwc = target.transpose(1, 2).transpose(2, 3)
            l1loss = loss_fn(target_nhwc, labels)
            rloss, ssim, uworg, uwpred = reconst_loss(images[:, :-1, :, :],
                                                      target_nhwc, labels)
            loss = (10.0 * l1loss) + (0.5 * rloss)  # + (0.3*ssim)
            # loss=l1loss
            avgl1loss += float(l1loss)
            avg_loss += float(loss)
            avgrloss += float(rloss)
            avgssimloss += float(ssim)

            train_mse += MSE(target_nhwc, labels).item()

            loss.backward()
            optimizer.step()
            global_step += 1

            if (i + 1) % 10 == 0:
                avg_loss = avg_loss / 10
                print("Epoch[%d/%d] Batch [%d/%d] Loss: %.4f" %
                      (epoch + 1, args.n_epoch, i + 1, len(trainloader),
                       avg_loss))
                avg_loss = 0.0

            if args.tboard and (i + 1) % 10 == 0:
                show_unwarp_tnsboard(global_step, writer, uwpred, uworg, 8,
                                     'Train GT unwarp', 'Train Pred Unwarp')
                writer.add_scalars(
                    'Train', {
                        'BM_L1 Loss/train': avgl1loss / (i + 1),
                        'CB_Recon Loss/train': avgrloss / (i + 1),
                        'CB_SSIM Loss/train': avgssimloss / (i + 1)
                    }, global_step)

        avgssimloss = avgssimloss / len(trainloader)
        avgrloss = avgrloss / len(trainloader)
        avgl1loss = avgl1loss / len(trainloader)
        train_mse = train_mse / len(trainloader)
        print("Training L1:%4f" % (avgl1loss))
        print("Training MSE:'{}'".format(train_mse))
        train_losses = [avgl1loss, train_mse, avgrloss, avgssimloss]
        lrate = get_lr(optimizer)
        write_log_file(log_file_name, train_losses, epoch + 1, lrate, 'Train')

        if args.tboard:
            writer.add_scalar('BM: L1 Loss/train', avgl1loss, epoch + 1)
            writer.add_scalar('CB: Recon Loss/train', avgrloss, epoch + 1)
            writer.add_scalar('CB: SSIM Loss/train', avgssimloss, epoch + 1)
            writer.add_scalar('MSE: MSE/train', train_mse, epoch + 1)

        model.eval()
        val_loss = 0.0
        val_l1loss = 0.0
        val_mse = 0.0
        val_rloss = 0.0
        val_ssimloss = 0.0

        for i_val, (images_val, labels_val) in tqdm(enumerate(valloader)):
            with torch.no_grad():
                images_val = Variable(images_val.cuda())
                labels_val = Variable(labels_val.cuda())
                target = model(images_val[:, 3:, :, :])
                target_nhwc = target.transpose(1, 2).transpose(2, 3)
                pred = target_nhwc.data.cpu()
                gt = labels_val.cpu()
                l1loss = loss_fn(target_nhwc, labels_val)
                rloss, ssim, uworg, uwpred = reconst_loss(
                    images_val[:, :-1, :, :], target_nhwc, labels_val)
                val_l1loss += float(l1loss.cpu())
                val_rloss += float(rloss.cpu())
                val_ssimloss += float(ssim.cpu())
                val_mse += float(MSE(pred, gt))
            if args.tboard:
                show_unwarp_tnsboard(epoch + 1, writer, uwpred, uworg, 8,
                                     'Val GT unwarp', 'Val Pred Unwarp')

        val_l1loss = val_l1loss / len(valloader)
        val_mse = val_mse / len(valloader)
        val_ssimloss = val_ssimloss / len(valloader)
        val_rloss = val_rloss / len(valloader)
        print("val loss at epoch {}:: {}".format(epoch + 1, val_l1loss))
        print("val mse: {}".format(val_mse))
        val_losses = [val_l1loss, val_mse, val_rloss, val_ssimloss]
        write_log_file(log_file_name, val_losses, epoch + 1, lrate, 'Val')
        if args.tboard:
            # log the val losses
            writer.add_scalar('BM: L1 Loss/val', val_l1loss, epoch + 1)
            writer.add_scalar('CB: Recon Loss/val', val_rloss, epoch + 1)
            writer.add_scalar('CB: SSIM Loss/val', val_ssimloss, epoch + 1)
            writer.add_scalar('MSE: MSE/val', val_mse, epoch + 1)

        if args.tboard:
            # plot train against val
            writer.add_scalars('BM_L1_Loss', {
                'train': avgl1loss,
                'val': val_l1loss
            }, epoch + 1)
            writer.add_scalars('CB_Recon_Loss', {
                'train': avgrloss,
                'val': val_rloss
            }, epoch + 1)
            writer.add_scalars('CB_SSIM_Loss', {
                'train': avgssimloss,
                'val': val_ssimloss
            }, epoch + 1)
            writer.add_scalars('MSE_Mean_square_error', {
                'train': train_mse,
                'val': val_mse
            }, epoch + 1)

        # reduce learning rate
        sched.step(val_mse)

        if val_mse < best_val_mse:
            best_val_mse = val_mse
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state, args.logdir + "{}_{}_{}_{}_{}_best_model.pkl".format(
                    args.arch, epoch + 1, val_mse, train_mse, experiment_name))

        if (epoch + 1) % 10 == 0:
            state = {
                'epoch': epoch + 1,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
            }
            torch.save(
                state, args.logdir + "{}_{}_{}_{}_{}_model.pkl".format(
                    args.arch, epoch + 1, val_mse, train_mse, experiment_name))
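Example #2's train expects an args namespace built by the launching script, which is not shown. A hypothetical argparse driver, using only the attribute names the function actually reads (all defaults and help strings below are assumptions):

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Train the backward-mapping network')
    parser.add_argument('--data_path', type=str, default='',
                        help='Path to the doc3d data')
    parser.add_argument('--arch', type=str, default='dnetccnl',
                        help='Model architecture to train')
    parser.add_argument('--img_rows', type=int, default=128)
    parser.add_argument('--img_cols', type=int, default=128)
    parser.add_argument('--n_epoch', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--l_rate', type=float, default=1e-4)
    parser.add_argument('--resume', type=str, default=None,
                        help='Checkpoint to resume from')
    parser.add_argument('--logdir', type=str, default='./checkpoints-bm/')
    parser.add_argument('--tboard', action='store_true',
                        help='Enable tensorboard logging')
    args = parser.parse_args()
    train(args)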
Example #3
def train(n_epoch=50, batch_size=32, resume=False, wc_path='', bm_path=''):
    wc_model_name = 'unetnc'
    bm_model_name = 'dnetccnl'

    # Setup dataloader
    data_path = 'C:/Users/yuttapichai.lam/dev-environment/doc3d'
    data_loader = get_loader('doc3djoint')
    t_loader = data_loader(data_path,
                           is_transform=True,
                           img_size=(256, 256),
                           bm_size=(128, 128))
    v_loader = data_loader(data_path,
                           split='val',
                           is_transform=True,
                           img_size=(256, 256),
                           bm_size=(128, 128))

    trainloader = data.DataLoader(t_loader,
                                  batch_size=batch_size,
                                  num_workers=8,
                                  shuffle=True)
    valloader = data.DataLoader(v_loader, batch_size=batch_size, num_workers=8)

    # Last layer activation
    htan = nn.Hardtanh(0, 1.0)

    # Load models
    print('Loading')
    wc_model = get_model(wc_model_name, n_classes=3, in_channels=3)
    wc_model = torch.nn.DataParallel(wc_model,
                                     device_ids=range(
                                         torch.cuda.device_count()))
    wc_model.cuda()
    bm_model = get_model(bm_model_name, n_classes=2, in_channels=3)
    bm_model = torch.nn.DataParallel(bm_model,
                                     device_ids=range(
                                         torch.cuda.device_count()))
    bm_model.cuda()

    # Setup optimizer and learning rate reduction
    print('Setting optimizer')
    optimizer = torch.optim.Adam([{
        'params': wc_model.parameters()
    }, {
        'params': bm_model.parameters()
    }],
                                 lr=1e-4,
                                 weight_decay=5e-4,
                                 amsgrad=True)
    schedule = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                          mode='min',
                                                          factor=0.5,
                                                          patience=3,
                                                          verbose=True)

    # Setup losses
    MSE = nn.MSELoss()
    loss_fn = nn.L1Loss()
    reconst_loss = recon_lossc.Unwarploss()
    g_loss = grad_loss.Gradloss(window_size=5, padding=2)

    epoch_start = 0

    if resume:
        print('Resume from previous state')
        wc_chkpnt = torch.load(wc_path)
        wc_model.load_state_dict(wc_chkpnt['model_state'])
        bm_chkpnt = torch.load(bm_path)
        bm_model.load_state_dict(bm_chkpnt['model_state'])
        # optimizer.load_state_dict(
        #     [wc_chkpnt['optimizer_state'], bm_chkpnt['optimizer_state']])
        epoch_start = bm_chkpnt['epoch']

    best_valwc_mse = 9999999.0
    best_valbm_mse = 9999999.0
    print(f'Start from epoch {epoch_start} of {n_epoch}')
    print('Starting')
    for epoch in range(epoch_start, n_epoch):
        print(f'Epoch: {epoch}')
        # Loss initialization
        avg_loss = 0.0
        avg_wcloss = 0.0
        avgwcl1loss = 0.0
        avg_gloss = 0.0
        train_wcmse = 0.0
        avg_bmloss = 0.0
        avgbml1loss = 0.0
        avgrloss = 0.0
        avgssimloss = 0.0
        train_bmmse = 0.0

        avg_const_l1 = 0.0
        avg_const_mse = 0.0

        # Start training
        wc_model.train()
        bm_model.train()

        print('Training')

        for i, (imgs, wcs, bms, recons, ims, lbls) in enumerate(trainloader):
            images = Variable(imgs.cuda())
            wc_labels = Variable(wcs.cuda())
            bm_labels = Variable(bms.cuda())
            recon_labels = Variable(recons.cuda())
            im_inputs = Variable(ims.cuda())
            labels = Variable(lbls.cuda())

            optimizer.zero_grad()

            # Train WC network
            wc_out = wc_model(images)
            wc_out = F.interpolate(wc_out,
                                   size=(256, 256),
                                   mode='bilinear',
                                   align_corners=True)
            bm_inp = F.interpolate(wc_out,
                                   size=(128, 128),
                                   mode='bilinear',
                                   align_corners=True)
            bm_inp = htan(bm_inp)
            wc_pred = htan(wc_out)

            wc_l1loss = loss_fn(wc_pred, wc_labels)
            wc_gloss = g_loss(wc_pred, wc_labels)
            wc_mse = MSE(wc_pred, wc_labels)
            wc_loss = wc_l1loss + (0.2 * wc_gloss)

            # WC Loss
            avgwcl1loss += float(wc_l1loss)
            avg_gloss += float(wc_gloss)
            train_wcmse += float(wc_mse)
            avg_wcloss += float(wc_loss)

            # Train BM network
            bm_out = bm_model(bm_inp)
            bm_out = bm_out.transpose(1, 2).transpose(2, 3)

            bm_l1loss = loss_fn(bm_out, bm_labels)
            rloss, ssim, _, _ = reconst_loss(recon_labels, bm_out, bm_labels)
            bm_mse = MSE(bm_out, bm_labels)
            bm_loss = (10.0 * bm_l1loss) + (0.5 * rloss)

            # Loss between unwarped GT and unwarped Predict
            im_ins = im_inputs[:, :3, :, :]
            bm_out = bm_out.double()
            label_in = labels[:, :3, :, :]
            bm_labels = bm_labels.double()
            uwpred = unwarp(im_ins, bm_out)
            uworg = unwarp(label_in, bm_labels)
            const_l1 = loss_fn(uwpred, uworg)
            const_mse = MSE(uwpred, uworg)

            # BM Loss
            avg_const_l1 += float(const_l1)
            avg_const_mse += float(const_mse)
            avgbml1loss += float(bm_l1loss)
            avgrloss += float(rloss)
            avgssimloss += float(ssim)
            train_bmmse += float(bm_mse)
            avg_bmloss += float(bm_loss)

            # Step loss
            loss = (0.5 * wc_loss) + (0.5 * bm_loss)
            avg_loss += float(loss)

            # print(f'Epoch[{epoch}/{n_epoch}] Loss: {loss:.6f} Const Loss: {const_l1:.6f}')
            if (i + 1) % 10 == 0:
                # Show image
                _, ax = plt.subplots(1, 2)
                ax[0].imshow(uworg[0].cpu().detach().numpy().transpose(
                    (1, 2, 0)))
                ax[1].imshow(uwpred[0].cpu().detach().numpy().transpose(
                    (1, 2, 0)))
                plt.show()
                print(
                    f'Epoch[{epoch}/{n_epoch}] Batch[{i+1}/{len(trainloader)}] Loss: {avg_loss/(i+1):.6f} Const Loss: {avg_const_l1/(i+1):.6f}'
                )

            loss.backward()
            # const_l1.backward()
            optimizer.step()

        len_trainset = len(trainloader)
        avg_const_l1 = avg_const_l1 / len_trainset
        train_wcmse = train_wcmse / len_trainset
        train_bmmse = train_bmmse / len_trainset
        train_losses = [
            avgwcl1loss / len_trainset, train_wcmse, avg_gloss / len_trainset,
            avgbml1loss / len_trainset, train_bmmse, avgrloss / len_trainset,
            avgssimloss / len_trainset, avg_const_l1,
            avg_const_mse / len_trainset
        ]
        print(
            f'WC L1 loss: {train_losses[0]} WC MSE: {train_losses[1]} WC GLoss: {train_losses[2]}'
        )
        print(
            f'BM L1 Loss: {train_losses[3]} BM MSE: {train_losses[4]} BM RLoss: {train_losses[5]} BM SSIM Loss: {train_losses[6]}'
        )
        print(
            f'Reconstruction against GT => Loss: {train_losses[7]} MSE: {train_losses[8]}'
        )
        wc_model.eval()
        bm_model.eval()

        wc_val_l1 = 0.0
        wc_val_mse = 0.0
        wc_val_gloss = 0.0
        bm_val_l1 = 0.0
        bm_val_mse = 0.0
        bm_val_rloss = 0.0
        bm_val_ssim = 0.0
        avg_const_l1_val = 0.0
        avg_const_mse_val = 0.0

        print('Validating')

        for i_val, (imgs_val, wcs_val, bms_val, recons_val, ims_val,
                    lbls_val) in tqdm(enumerate(valloader)):
            with torch.no_grad():
                images_val = Variable(imgs_val.cuda())
                wc_labels_val = Variable(wcs_val.cuda())
                bm_labels_val = Variable(bms_val.cuda())
                recon_labels_val = Variable(recons_val.cuda())
                ims_labels_val = Variable(ims_val.cuda())
                labels_val = Variable(lbls_val.cuda())

                # Val WC Network
                wc_out_val = wc_model(images_val)
                wc_out_val = F.interpolate(wc_out_val,
                                           size=(256, 256),
                                           mode='bilinear',
                                           align_corners=True)
                bm_inp_val = F.interpolate(wc_out_val,
                                           size=(128, 128),
                                           mode='bilinear',
                                           align_corners=True)
                bm_inp_val = htan(bm_inp_val)
                wc_pred_val = htan(wc_out_val)

                wc_l1 = loss_fn(wc_pred_val, wc_labels_val)
                wc_gloss = g_loss(wc_pred_val, wc_labels_val)
                wc_mse = MSE(wc_pred_val, wc_labels_val)

                # Val BM network
                bm_out_val = bm_model(bm_inp_val)
                bm_out_val = bm_out_val.transpose(1, 2).transpose(2, 3)

                bm_l1 = loss_fn(bm_out_val, bm_labels_val)
                rloss, ssim, _, _ = reconst_loss(recon_labels_val, bm_out_val,
                                                 bm_labels_val)
                bm_mse = MSE(bm_out_val, bm_labels_val)

                # Loss between unwarped GT and unwarped Predict
                im_ins_val = ims_labels_val[:, :3, :, :]
                bm_out_val = bm_out_val.double()
                lbl_ins_val = labels_val[:, :3, :, :]
                bm_labels_val = bm_labels_val.double()
                uwpred_val = unwarp(im_ins_val, bm_out_val)
                uworg_val = unwarp(lbl_ins_val, bm_labels_val)
                const_l1_val = loss_fn(uwpred_val, uworg_val)
                const_mse_val = MSE(uwpred_val, uworg_val)

                # Val Loss
                avg_const_l1_val += float(const_l1_val)
                avg_const_mse_val += float(const_mse_val)
                wc_val_l1 += float(wc_l1.cpu())
                wc_val_gloss += float(wc_gloss.cpu())
                wc_val_mse += float(wc_mse.cpu())

                bm_val_l1 += float(bm_l1.cpu())
                bm_val_mse += float(bm_mse.cpu())
                bm_val_rloss += float(rloss.cpu())
                bm_val_ssim += float(ssim.cpu())

        len_valset = len(valloader)
        avg_const_l1_val = avg_const_l1_val / len_valset
        wc_val_mse = wc_val_mse / len_valset
        bm_val_mse = bm_val_mse / len_valset
        val_losses = [
            wc_val_l1 / len_valset, wc_val_mse, wc_val_gloss / len_valset,
            bm_val_l1 / len_valset, bm_val_mse, bm_val_rloss / len_valset,
            bm_val_ssim / len_valset, avg_const_l1_val,
            avg_const_mse_val / len_valset
        ]
        print(
            f'WC L1 loss: {val_losses[0]} WC MSE: {val_losses[1]} WC GLoss: {val_losses[2]}'
        )
        print(
            f'BM L1 Loss: {val_losses[3]} BM MSE: {val_losses[4]} BM RLoss: {val_losses[5]} BM SSIM Loss: {val_losses[6]}'
        )
        print(
            f'Reconstruction against GT => Loss: {val_losses[7]} MSE: {val_losses[8]}'
        )
        # Reduce learning rate
        schedule.step(bm_val_mse)

        if wc_val_mse < best_valwc_mse:
            best_valwc_mse = wc_val_mse
            state = {'epoch': epoch, 'model_state': wc_model.state_dict()}
            torch.save(
                state,
                f'./checkpoints-wc/unetnc_{epoch}_wc_{wc_val_mse}_{train_wcmse}_best_model.pkl'
            )

        if bm_val_mse < best_valbm_mse:
            best_valbm_mse = bm_val_mse
            state = {'epoch': epoch, 'model_state': bm_model.state_dict()}
            torch.save(
                state,
                f'./checkpoints-bm/dnetccnl_{epoch}_bm_{bm_val_mse}_{train_bmmse}_best_model.pkl'
            )
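Example #3 also depends on an unwarp helper that is not shown. A minimal sketch in terms of torch.nn.functional.grid_sample, assuming the backward map is channel-last (N, H, W, 2) and already normalized to grid_sample's [-1, 1] coordinate range (rescale with bm * 2 - 1 first if the loader keeps it in [0, 1]):

import torch.nn.functional as F

def unwarp(img, bm):
    # img: (N, C, H, W) texture to resample; bm: (N, H, W, 2) backward map.
    # grid_sample requires the grid dtype to match the input dtype, so cast
    # the map (the caller above converts it to double).
    return F.grid_sample(img, bm.to(img.dtype), align_corners=True)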