# Example #1
# 0
def train_proxnet(args):
    """Train PGD-Net (``ProxNet``) for MRF reconstruction.

    Optimizes the network against a weighted sum of three MSE terms:
    tissue-map loss (``loss_m``), image loss (``loss_x``) and measurement
    consistency loss (``loss_y``). Logs per-epoch averages, optionally saves
    periodic checkpoints, and writes the final model to
    ``args.save_model_dir``.

    Args:
        args: Namespace-like configuration. Fields read here include
            ``cuda``, ``seed``, ``sampling``, ``batch_size``, ``lr``,
            ``lr2``, ``weight_decay``, ``epochs``, ``time_step``,
            ``loss_weight`` (dict with keys ``'m'``, ``'x'``, ``'y'``),
            ``noise_sigam``, ``save_model_dir``, ``filename``,
            ``checkpoint_model_dir`` and ``checkpoint_interval``.

    Side effects:
        Seeds numpy/torch RNGs, writes log records, saves ``.pt`` files.
    """
    check_paths(args)
    # init GPU configuration; set_gpu returns the tensor dtype to cast to
    args.dtype = set_gpu(args.cuda)

    # init seed for reproducibility (numpy + torch)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # define training data
    train_dataset = data.MRFData(mod='train', sampling=args.sampling)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # init operators (subsampling + subspace dimension reduction + Fourier transformation)
    operator = OperatorBatch(sampling=args.sampling.upper()).cuda()
    H, HT = operator.forward, operator.adjoint  # forward model and its adjoint
    bloch = BLOCH().cuda()

    # init PGD-Net (proxnet)
    proxnet = ProxNet(args).cuda()

    # init optimizer: separate learning rates for the network weights and
    # the (scalar/step-size) parameter alpha; weight decay only on the net
    optimizer = torch.optim.Adam([{
        'params': proxnet.transformnet.parameters(),
        'lr': args.lr,
        'weight_decay': args.weight_decay
    }, {
        'params': proxnet.alpha,
        'lr': args.lr2
    }])

    # drop the learning rate by 10x after epoch 20
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[20],
                                                     gamma=0.1)

    # init loss
    mse_loss = torch.nn.MSELoss()

    # init meters
    log = LOG(args.save_model_dir,
              filename=args.filename,
              field_name=[
                  'iter', 'loss_m', 'loss_x', 'loss_y', 'loss_total', 'alpha'
              ])

    # epoch-level averages; initialized so the final save works even when
    # args.epochs == 0
    loss_epoch = 0
    loss_m_epoch, loss_x_epoch, loss_y_epoch = 0, 0, 0

    # start PGD-Net training
    print('start training...')
    for e in range(args.epochs):
        proxnet.train()
        loss_m_seq = []
        loss_x_seq = []
        loss_y_seq = []
        loss_total_seq = []

        for x, m, y in train_loader:
            # convert data type (cuda)
            x, m, y = x.type(args.dtype), m.type(args.dtype), y.type(
                args.dtype)
            # add measurement noise to y before applying the adjoint.
            # NOTE(review): 'noise_sigam' is a typo for 'noise_sigma', but
            # the attribute name is set by the argument parser elsewhere —
            # rename there first before fixing this read.
            noise = args.noise_sigam * torch.randn(y.shape).type(args.dtype)
            HTy = HT(y + noise).type(args.dtype)

            # PGD-Net computation (iteration)
            # output the reconstructions (sequence) of MRF image x and its tissue property map m
            m_seq, x_seq = proxnet(HTy, H, HT, bloch)

            loss_x, loss_y, loss_m = 0, 0, 0
            # measurement-consistency loss averaged over all unrolled steps
            for t in range(args.time_step):
                loss_y += mse_loss(H(x_seq[t]), y) / args.time_step
            # per-channel weighted tissue-map loss on the final estimate
            for i in range(3):
                loss_m += args.loss_weight['m'][i] * mse_loss(
                    m_seq[-1][:, i, :, :], m[:, i, :, :])
            # image loss on the final estimate
            loss_x = mse_loss(x_seq[-1], x)

            # compute total loss (loss_m is implicitly weighted by the
            # per-channel weights above)
            loss_total = loss_m + args.loss_weight[
                'x'] * loss_x + args.loss_weight['y'] * loss_y

            # update gradient
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            # update meters
            loss_m_seq.append(loss_m.item())
            loss_x_seq.append(loss_x.item())
            loss_y_seq.append(loss_y.item())
            loss_total_seq.append(loss_total.item())

        # (scheduled) update learning rate once per epoch
        scheduler.step()

        # print meters: average each loss over the epoch's batches
        loss_m_epoch = np.mean(loss_m_seq)
        loss_x_epoch = np.mean(loss_x_seq)
        loss_y_epoch = np.mean(loss_y_seq)
        loss_epoch = np.mean(loss_total_seq)

        log.record(e + 1, loss_m_epoch, loss_x_epoch, loss_y_epoch, loss_epoch,
                   proxnet.alpha.detach().cpu().numpy())
        logT(
            "==>Epoch {}\tloss_m: {:.6f}\tloss_x: {:.6f}\tloss_y: {:.6f}\tloss_total: {:.6f}\talpha: {}"
            .format(e + 1, loss_m_epoch, loss_x_epoch, loss_y_epoch,
                    loss_epoch,
                    proxnet.alpha.detach().cpu().numpy()))

        # save checkpoint every checkpoint_interval epochs
        if args.checkpoint_model_dir is not None and (
                e + 1) % args.checkpoint_interval == 0:
            proxnet.eval()
            ckpt = {
                'epoch': e + 1,
                'loss_m': loss_m_epoch,
                'loss_x': loss_x_epoch,
                'loss_y': loss_y_epoch,
                'total_loss': loss_epoch,
                'net_state_dict': proxnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'alpha': proxnet.alpha.detach().cpu().numpy()
            }
            # fix: filename now uses e + 1 so it matches ckpt['epoch']
            # (previously it was off by one, naming the file after e)
            torch.save(
                ckpt,
                os.path.join(args.checkpoint_model_dir,
                             'ckp_epoch_{}.pt'.format(e + 1)))
            proxnet.train()

    # save final model
    proxnet.eval()
    state = {
        'epoch': args.epochs,
        'loss_m': loss_m_epoch,
        'loss_x': loss_x_epoch,
        'loss_y': loss_y_epoch,
        'total_loss': loss_epoch,
        'alpha': proxnet.alpha.detach().cpu().numpy(),
        'net_state_dict': proxnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    save_model_path = os.path.join(args.save_model_dir, log.filename + '.pt')
    torch.save(state, save_model_path)
    print("\nDone, trained model saved at", save_model_path)