def train_proxnet(args):
    check_paths(args)
    # init GPU configuration
    args.dtype = set_gpu(args.cuda)
    # init seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # define training data
    train_dataset = data.MRFData(mod='train', sampling=args.sampling)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # init operators (subsampling + subspace dimension reduction + Fourier transformation)
    operator = OperatorBatch(sampling=args.sampling.upper()).cuda()
    H, HT = operator.forward, operator.adjoint
    bloch = BLOCH().cuda()

    # init PGD-Net (proxnet)
    proxnet = ProxNet(args).cuda()

    # init optimizer: one parameter group for the proximal network, one for the step size alpha
    optimizer = torch.optim.Adam([
        {'params': proxnet.transformnet.parameters(), 'lr': args.lr, 'weight_decay': args.weight_decay},
        {'params': proxnet.alpha, 'lr': args.lr2}
    ])
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20], gamma=0.1)

    # init loss
    mse_loss = torch.nn.MSELoss()  # .cuda()

    # init meters
    log = LOG(args.save_model_dir,
              filename=args.filename,
              field_name=['iter', 'loss_m', 'loss_x', 'loss_y', 'loss_total', 'alpha'])
    loss_epoch = 0
    loss_m_epoch, loss_x_epoch, loss_y_epoch = 0, 0, 0

    # start PGD-Net training
    print('start training...')
    for e in range(args.epochs):
        proxnet.train()
        loss_m_seq = []
        loss_x_seq = []
        loss_y_seq = []
        loss_total_seq = []
        for x, m, y in train_loader:
            # convert data type (cuda)
            x, m, y = x.type(args.dtype), m.type(args.dtype), y.type(args.dtype)
            # add Gaussian measurement noise to the k-space data y
            noise = args.noise_sigam * torch.randn(y.shape).type(args.dtype)
            HTy = HT(y + noise).type(args.dtype)

            # PGD-Net computation (iteration):
            # output the reconstruction sequences of the MRF image x and its tissue property map m
            m_seq, x_seq = proxnet(HTy, H, HT, bloch)

            loss_x, loss_y, loss_m = 0, 0, 0
            # data-consistency loss, averaged over the PGD iterations
            for t in range(args.time_step):
                loss_y += mse_loss(H(x_seq[t]), y) / args.time_step
            # tissue-property loss on the final iterate, weighted per property channel
            for i in range(3):
                loss_m += args.loss_weight['m'][i] * mse_loss(m_seq[-1][:, i, :, :], m[:, i, :, :])
            # image loss on the final iterate
            loss_x = mse_loss(x_seq[-1], x)

            # compute total loss
            loss_total = loss_m + args.loss_weight['x'] * loss_x + args.loss_weight['y'] * loss_y

            # update gradient
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            # update meters
            loss_m_seq.append(loss_m.item())
            loss_x_seq.append(loss_x.item())
            loss_y_seq.append(loss_y.item())
            loss_total_seq.append(loss_total.item())

        # (scheduled) update learning rate
        scheduler.step()

        # print meters
        loss_m_epoch = np.mean(loss_m_seq)
        loss_x_epoch = np.mean(loss_x_seq)
        loss_y_epoch = np.mean(loss_y_seq)
        loss_epoch = np.mean(loss_total_seq)
        log.record(e + 1, loss_m_epoch, loss_x_epoch, loss_y_epoch, loss_epoch,
                   proxnet.alpha.detach().cpu().numpy())
        logT("==>Epoch {}\tloss_m: {:.6f}\tloss_x: {:.6f}\tloss_y: {:.6f}\tloss_total: {:.6f}\talpha: {}"
             .format(e + 1, loss_m_epoch, loss_x_epoch, loss_y_epoch, loss_epoch,
                     proxnet.alpha.detach().cpu().numpy()))

        # save checkpoint
        if args.checkpoint_model_dir is not None and (e + 1) % args.checkpoint_interval == 0:
            proxnet.eval()
            ckpt = {
                'epoch': e + 1,
                'loss_m': loss_m_epoch,
                'loss_x': loss_x_epoch,
                'loss_y': loss_y_epoch,
                'total_loss': loss_epoch,
                'net_state_dict': proxnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'alpha': proxnet.alpha.detach().cpu().numpy()
            }
            torch.save(ckpt,
                       os.path.join(args.checkpoint_model_dir,
                                    'ckp_epoch_{}.pt'.format(e + 1)))
            proxnet.train()

    # save model
    proxnet.eval()
    state = {
        'epoch': args.epochs,
        'loss_m': loss_m_epoch,
        'loss_x': loss_x_epoch,
        'loss_y': loss_y_epoch,
        'total_loss': loss_epoch,
        'alpha': proxnet.alpha.detach().cpu().numpy(),
        'net_state_dict': proxnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    save_model_path = os.path.join(args.save_model_dir, log.filename + '.pt')
    torch.save(state, save_model_path)
    print("\nDone, trained model saved at", save_model_path)
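

# Example invocation (a minimal sketch, not part of the original script).
# The argument names below mirror the attributes accessed in train_proxnet();
# the concrete default values, the argparse wiring, and the loss_weight layout
# are assumptions inferred from the function body and may differ in the repo.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train PGD-Net for MRF reconstruction')
    parser.add_argument('--cuda', type=int, default=0)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--sampling', type=str, default='spiral')
    parser.add_argument('--batch_size', type=int, default=2)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--time_step', type=int, default=5)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr2', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--noise_sigam', type=float, default=0.01)
    parser.add_argument('--checkpoint_interval', type=int, default=10)
    parser.add_argument('--checkpoint_model_dir', type=str, default='./checkpoints')
    parser.add_argument('--save_model_dir', type=str, default='./models')
    parser.add_argument('--filename', type=str, default='proxnet')
    args = parser.parse_args()

    # loss weights for the tissue-property (m), image (x) and data-consistency (y) terms;
    # the dict structure (a 3-element list for 'm', scalars for 'x' and 'y') is inferred
    # from how args.loss_weight is indexed inside train_proxnet().
    args.loss_weight = {'m': [1.0, 1.0, 1.0], 'x': 1.0, 'y': 1.0}

    train_proxnet(args)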