示例#1
0
文件: spm.py 项目: schwarty/nignore
def parse_spm8_preproc(work_dir, step):
    """Extract anatomy and slice-timing information from an SPM8 preproc step.

    Parameters
    ----------
    work_dir : str
        Directory used to re-root the paths stored in the SPM job.
    step : object
        One step of a loaded SPM job structure; only its ``spatial`` and
        ``temporal`` attributes are read, and only when present.

    Returns
    -------
    dict
        Possibly containing ``'anatomy'`` and ``'wmanatomy'`` (when
        ``step.spatial.preproc`` exists) and the slice-timing fields
        ``'n_slices'``, ``'ref_slice'``, ``'slice_order'``, ``'ta'``,
        ``'tr'``, the per-session scan lists ``'bold'`` / ``'swabold'``
        and their lengths ``'n_scans'`` (when ``step.temporal`` exists).
    """
    doc = {}

    if hasattr(step, 'spatial') and hasattr(step.spatial, 'preproc'):
        anatomy = makeup_path(work_dir, check_path(step.spatial.preproc.data))
        doc['anatomy'] = anatomy
        doc['wmanatomy'] = prefix_filename(anatomy, 'wm')

    if not hasattr(step, 'temporal'):
        return doc

    st = step.temporal.st
    doc['n_slices'] = int(st.nslices)
    doc['ref_slice'] = int(st.refslice)
    doc['slice_order'] = st.so.tolist()
    doc['ta'] = float(st.ta)
    doc['tr'] = float(st.tr)
    doc['bold'] = []
    doc['swabold'] = []

    # When the first element is 0-d, ``scans`` is treated as one session;
    # otherwise it is iterated as a sequence of sessions.
    sessions = [st.scans] if len(st.scans[0].shape) == 0 else st.scans
    for scans in sessions:
        data_dir = find_data_dir(work_dir, str(scans[0]))
        names = [os.path.split(str(scan))[1] for scan in scans]
        doc['bold'].append(
            check_paths([os.path.join(data_dir, name) for name in names]))
        doc['swabold'].append(check_paths(
            [prefix_filename(os.path.join(data_dir, name), 'swa')
             for name in names]))
    doc['n_scans'] = [len(paths) for paths in doc['bold']]
    return doc
示例#2
0
def produce_consume():
    """Run one producer/consumer round, optionally persist results, then sleep.

    Resolves paths via ``paths()``, checks them, then loads and validates the
    config. It starts 16 daemon threads running ``consumer.consume_domains``
    on a shared queue and runs ``Producer(...).get_doms()``, which presumably
    fills the queue. It then blocks on ``q.join()``. If
    ``config['write_to_file']`` is set, it spawns a process calling
    ``add_data`` with the consumer's domains. Finally it sleeps for
    ``config['interval']`` seconds.

    Exits the interpreter with status 1 if ``check_config`` raises or
    returns a non-None error.
    """
    # sys.exit instead of the site-provided exit(): the latter is meant for
    # the interactive shell and is absent under ``python -S``/frozen apps.
    import sys

    real_path, word_path, config_path = paths()
    check_paths(word_path, config_path)
    config = get_config(config_path)
    try:
        error = check_config(config)
    except Exception as e:
        # Report the validation failure by exception type and message.
        print(type(e).__name__, e)
        sys.exit(1)
    if error is not None:
        print(error)
        sys.exit(1)

    q = Queue()
    consumer = Consumer(q)
    # NOTE(review): a fresh set of consumer threads is started on every call;
    # if this function is invoked in a loop, confirm consume_domains returns,
    # otherwise daemon threads from earlier rounds accumulate.
    for _ in range(16):
        worker = Thread(target=consumer.consume_domains)
        worker.daemon = True
        worker.start()
    Producer(q, config, word_path).get_doms()
    q.join()

    if config['write_to_file']:
        print_red('writing to domains.json')
        # The writer process is started but not joined here.
        writer = Process(target=add_data,
                         args=(real_path, consumer.get_domains()))
        writer.start()
    print_red('sleeping zzzzz...')
    sleep(config['interval'])
示例#3
0
def parse_spm5_preproc(work_dir, step):
    """Extract motion, anatomy and slice-timing information from an SPM5 preproc step.

    Parameters
    ----------
    work_dir : str
        Directory used to re-root the paths stored in the SPM job.
    step : object
        One step of a loaded SPM job structure; its ``spatial`` and
        ``temporal`` attributes are inspected when present.

    Returns
    -------
    dict
        Depending on what ``step`` contains:

        - ``'motion'``: one ``rp_*.txt`` path per realigned session.
        - ``'anatomy'`` and ``'wmanatomy'``.
        - the slice-timing fields ``'n_slices'``, ``'ref_slice'``,
          ``'slice_order'``, ``'ta'`` and ``'tr'``.
        - the per-session scan lists ``'bold'`` and ``'swabold'``, with
          their lengths in ``'n_scans'``.

    Raises
    ------
    IndexError
        If a realigned session's directory holds no ``rp_*.txt`` file.
    """
    doc = {}

    if hasattr(step, 'spatial') and hasattr(step.spatial, 'realign'):
        scans = step.spatial.realign.estwrite.data
        # A 0-d first element means ``scans`` is a single session; otherwise
        # it is a sequence of sessions. Normalise both layouts to a list of
        # sessions so that every session, including a lone one, gets its
        # motion file.
        sessions = [scans] if len(scans[0].shape) == 0 else scans
        motion = []
        for session in sessions:
            data_dir = find_data_dir(work_dir, check_path(session[0]))
            # Sorted so the choice is deterministic if several files match.
            candidates = sorted(glob.glob(os.path.join(data_dir, 'rp_*.txt')))
            if not candidates:
                raise IndexError(
                    'no rp_*.txt motion file found in %s' % data_dir)
            motion.append(candidates[0])
        doc['motion'] = motion

    if hasattr(step, 'spatial') and isinstance(step.spatial, np.ndarray):
        doc['anatomy'] = makeup_path(work_dir,
                                     check_path(step.spatial[0].preproc.data))
        doc['wmanatomy'] = prefix_filename(
            makeup_path(
                work_dir,
                check_path(step.spatial[1].normalise.write.subj.resample)),
            'w')

    if hasattr(step, 'temporal'):
        st = step.temporal.st
        doc['n_slices'] = int(st.nslices)
        doc['ref_slice'] = int(st.refslice)
        doc['slice_order'] = st.so.tolist()
        doc['ta'] = float(st.ta)
        doc['tr'] = float(st.tr)
        doc['bold'] = []
        doc['swabold'] = []
        # Same single-session vs. multi-session normalisation as above.
        bold = [st.scans] if len(st.scans[0].shape) == 0 else st.scans
        for session in bold:
            data_dir = find_data_dir(work_dir, str(session[0]))
            doc['bold'].append(
                check_paths([
                    os.path.join(data_dir, os.path.split(str(x))[1])
                    for x in session
                ]))
            doc['swabold'].append(
                check_paths([
                    prefix_filename(
                        os.path.join(data_dir, os.path.split(str(x))[1]),
                        'swa')
                    for x in session
                ]))
        doc['n_scans'] = [len(s) for s in doc['bold']]
    return doc
示例#4
0
def main():
    """Validate the parsed command line and dispatch to train or stylize.

    Exits with status 1 when no subcommand was given or when CUDA was
    requested but ``torch.cuda.is_available()`` is false.
    """
    # get_args_parser() is used here as returning the parsed namespace.
    args = get_args_parser()

    failure = None
    if args.subcommand is None:
        failure = "ERROR: specify either train or eval"
    elif args.cuda and not torch.cuda.is_available():
        failure = "ERROR: cuda is not available, try running on CPU"
    if failure is not None:
        print(failure)
        sys.exit(1)

    is_train = args.subcommand == "train"
    print('Starting train...' if is_train else 'Starting stylization...')
    if is_train:
        utils.check_paths(args)
        train(args)
    else:
        utils.check_paths(args, train=False)
        stylize(args)
示例#5
0
文件: spm.py 项目: schwarty/nignore
def parse_spm5_preproc(work_dir, step):
    """Extract motion, anatomy and slice-timing information from an SPM5 preproc step.

    Parameters
    ----------
    work_dir : str
        Directory used to re-root the paths stored in the SPM job.
    step : object
        One step of a loaded SPM job structure; its ``spatial`` and
        ``temporal`` attributes are inspected when present.

    Returns
    -------
    dict
        Depending on what ``step`` contains:

        - ``'motion'``: one ``rp_*.txt`` path per realigned session.
        - ``'anatomy'`` and ``'wmanatomy'``.
        - the slice-timing fields ``'n_slices'``, ``'ref_slice'``,
          ``'slice_order'``, ``'ta'`` and ``'tr'``.
        - the per-session scan lists ``'bold'`` and ``'swabold'``, with
          their lengths in ``'n_scans'``.

    Raises
    ------
    IndexError
        If a realigned session's directory holds no ``rp_*.txt`` file.
    """
    doc = {}

    if hasattr(step, 'spatial') and hasattr(step.spatial, 'realign'):
        scans = step.spatial.realign.estwrite.data
        # A 0-d first element means ``scans`` is a single session; otherwise
        # it is a sequence of sessions. Normalise both layouts to a list of
        # sessions so that every session, including a lone one, gets its
        # motion file.
        sessions = [scans] if len(scans[0].shape) == 0 else scans
        motion = []
        for session in sessions:
            data_dir = find_data_dir(work_dir, check_path(session[0]))
            # Sorted so the choice is deterministic if several files match.
            candidates = sorted(glob.glob(os.path.join(data_dir, 'rp_*.txt')))
            if not candidates:
                raise IndexError(
                    'no rp_*.txt motion file found in %s' % data_dir)
            motion.append(candidates[0])
        doc['motion'] = motion

    if hasattr(step, 'spatial') and isinstance(step.spatial, np.ndarray):
        doc['anatomy'] = makeup_path(
            work_dir, check_path(step.spatial[0].preproc.data))
        doc['wmanatomy'] = prefix_filename(makeup_path(
            work_dir,
            check_path(step.spatial[1].normalise.write.subj.resample)),
            'w')

    if hasattr(step, 'temporal'):
        st = step.temporal.st
        doc['n_slices'] = int(st.nslices)
        doc['ref_slice'] = int(st.refslice)
        doc['slice_order'] = st.so.tolist()
        doc['ta'] = float(st.ta)
        doc['tr'] = float(st.tr)
        doc['bold'] = []
        doc['swabold'] = []
        # Same single-session vs. multi-session normalisation as above.
        bold = [st.scans] if len(st.scans[0].shape) == 0 else st.scans
        for session in bold:
            data_dir = find_data_dir(work_dir, str(session[0]))
            doc['bold'].append(check_paths(
                [os.path.join(data_dir, os.path.split(str(x))[1])
                 for x in session]))
            doc['swabold'].append(check_paths(
                [prefix_filename(os.path.join(
                    data_dir, os.path.split(str(x))[1]), 'swa')
                 for x in session]))
        doc['n_scans'] = [len(s) for s in doc['bold']]
    return doc
示例#6
0
def get_intra_preproc(mat_file, work_dir, n_scans, memory=Memory(None)):
    """Collect per-session motion parameters and scan paths from an SPM.mat file.

    Parameters
    ----------
    mat_file : str
        Path of the SPM.mat file. It is loaded through ``load_matfile`` and
        its ``'SPM'`` entry is used.
    work_dir : str
        Directory used to re-root the paths stored in the SPM structure.
    n_scans : sequence of int
        Number of scans per session. Its length selects the single- vs.
        multi-session layout of ``SPM.Sess``, and its values split
        ``SPM.xY.P`` into sessions for non-4D data.
    memory : Memory, optional
        Object whose ``cache`` wraps ``load_matfile`` (presumably a
        ``joblib.Memory``; ``Memory(None)`` does no caching).

    Returns
    -------
    dict
        ``'motion'``: one entry per session. Entries are the regressors in
        ``SPM.Sess[...].C.C``, or, if any session has none, an ``(n, 6)``
        array read from that session's ``rp_*.txt`` file.
        ``'swabold'``, ``'abold'``, ``'bold'``: one entry per session. Each
        entry is either a list of scan paths or a single path (4D images),
        with the latter two derived via ``strip_prefix_filename``.

    Raises
    ------
    IndexError
        If motion files are needed but a session directory has no
        ``rp_*.txt``.
    """
    mat = memory.cache(load_matfile)(mat_file)['SPM']
    preproc = {}

    get_motion_file = False
    if len(n_scans) > 1:
        preproc['motion'] = []
        for session in mat.Sess:
            preproc['motion'].append(session.C.C.tolist())
            if session.C.C.size == 0:
                get_motion_file = True
    else:
        preproc['motion'] = [mat.Sess.C.C.tolist()]
        if mat.Sess.C.C.size == 0:
            get_motion_file = True

    swabold = check_paths(mat.xY.P)
    if len(nb.load(makeup_path(work_dir, swabold[0])).shape) == 4:
        # 4D data: presumably one file per session, listed repeatedly in
        # xY.P. Deduplicate while keeping first-occurrence (session) order.
        # A plain np.unique would sort the paths and could misalign them
        # with the per-session 'motion' entries.
        swabold = np.asarray(swabold)
        _, first_seen = np.unique(swabold, return_index=True)
        swabold = swabold[np.sort(first_seen)]
    else:
        swabold = np.split(swabold, np.cumsum(n_scans)[:-1])

    if get_motion_file:
        preproc['motion'] = []

    for session in swabold:
        is_scan_list = isinstance(session, (list, np.ndarray))
        # For 4D data ``session`` is itself a path string; indexing it
        # would yield its first character instead of a file name.
        first_scan = session[0] if is_scan_list else session
        session_dir = find_data_dir(work_dir, check_path(first_scan))
        if get_motion_file:
            # Sorted so the choice is deterministic if several files match.
            candidates = sorted(
                glob.glob(os.path.join(session_dir, 'rp_*.txt')))
            if not candidates:
                raise IndexError(
                    'no rp_*.txt motion file found in %s' % session_dir)
            motion = np.fromfile(candidates[0], sep=' ')
            # -1 lets numpy infer the (integer) row count; a true division
            # would produce a float, which reshape rejects.
            preproc['motion'].append(motion.reshape(-1, 6))

        if is_scan_list:
            scans = [
                os.path.join(session_dir,
                             os.path.split(scan)[1].strip())
                for scan in session
            ]
            preproc.setdefault('swabold', []).append(scans)
            preproc.setdefault('abold', []).append(
                [strip_prefix_filename(scan, 2) for scan in scans])
            preproc.setdefault('bold', []).append(
                [strip_prefix_filename(scan, 3) for scan in scans])
        else:
            preproc.setdefault('swabold', []).append(session)
            preproc.setdefault('abold',
                               []).append(strip_prefix_filename(session, 2))
            preproc.setdefault('bold',
                               []).append(strip_prefix_filename(session, 3))

    return preproc
示例#7
0
文件: spm.py 项目: schwarty/nignore
def get_intra_preproc(mat_file, work_dir, n_scans, memory=Memory(None)):
    """Collect per-session motion parameters and scan paths from an SPM.mat file.

    Parameters
    ----------
    mat_file : str
        Path of the SPM.mat file. It is loaded through ``load_matfile`` and
        its ``'SPM'`` entry is used.
    work_dir : str
        Directory used to re-root the paths stored in the SPM structure.
    n_scans : sequence of int
        Number of scans per session. Its length selects the single- vs.
        multi-session layout of ``SPM.Sess``, and its values split
        ``SPM.xY.P`` into sessions for non-4D data.
    memory : Memory, optional
        Object whose ``cache`` wraps ``load_matfile`` (presumably a
        ``joblib.Memory``; ``Memory(None)`` does no caching).

    Returns
    -------
    dict
        ``'motion'``: one entry per session. Entries are the regressors in
        ``SPM.Sess[...].C.C``, or, if any session has none, an ``(n, 6)``
        array read from that session's ``rp_*.txt`` file.
        ``'swabold'``, ``'abold'``, ``'bold'``: one entry per session. Each
        entry is either a list of scan paths or a single path (4D images),
        with the latter two derived via ``strip_prefix_filename``.

    Raises
    ------
    IndexError
        If motion files are needed but a session directory has no
        ``rp_*.txt``.
    """
    mat = memory.cache(load_matfile)(mat_file)['SPM']
    preproc = {}

    get_motion_file = False
    if len(n_scans) > 1:
        preproc['motion'] = []
        for session in mat.Sess:
            preproc['motion'].append(session.C.C.tolist())
            if session.C.C.size == 0:
                get_motion_file = True
    else:
        preproc['motion'] = [mat.Sess.C.C.tolist()]
        if mat.Sess.C.C.size == 0:
            get_motion_file = True

    swabold = check_paths(mat.xY.P)
    if len(nb.load(makeup_path(work_dir, swabold[0])).shape) == 4:
        # 4D data: presumably one file per session, listed repeatedly in
        # xY.P. Deduplicate while keeping first-occurrence (session) order.
        # A plain np.unique would sort the paths and could misalign them
        # with the per-session 'motion' entries.
        swabold = np.asarray(swabold)
        _, first_seen = np.unique(swabold, return_index=True)
        swabold = swabold[np.sort(first_seen)]
    else:
        swabold = np.split(swabold, np.cumsum(n_scans)[:-1])

    if get_motion_file:
        preproc['motion'] = []

    for session in swabold:
        is_scan_list = isinstance(session, (list, np.ndarray))
        # For 4D data ``session`` is itself a path string; indexing it
        # would yield its first character instead of a file name.
        first_scan = session[0] if is_scan_list else session
        session_dir = find_data_dir(work_dir, check_path(first_scan))
        if get_motion_file:
            # Sorted so the choice is deterministic if several files match.
            candidates = sorted(
                glob.glob(os.path.join(session_dir, 'rp_*.txt')))
            if not candidates:
                raise IndexError(
                    'no rp_*.txt motion file found in %s' % session_dir)
            motion = np.fromfile(candidates[0], sep=' ')
            # -1 lets numpy infer the (integer) row count; a true division
            # would produce a float, which reshape rejects.
            preproc['motion'].append(motion.reshape(-1, 6))

        if is_scan_list:
            scans = [os.path.join(session_dir, os.path.split(scan)[1].strip())
                     for scan in session]
            preproc.setdefault('swabold', []).append(scans)
            preproc.setdefault('abold', []).append(
                [strip_prefix_filename(scan, 2) for scan in scans])
            preproc.setdefault('bold', []).append(
                [strip_prefix_filename(scan, 3) for scan in scans])
        else:
            preproc.setdefault('swabold', []).append(session)
            preproc.setdefault('abold', []).append(
                strip_prefix_filename(session, 2))
            preproc.setdefault('bold', []).append(
                strip_prefix_filename(session, 3))

    return preproc
示例#8
0
def parse_spm8_preproc(work_dir, step):
    """Collect anatomy and slice-timing metadata from an SPM8 preproc step.

    Parameters
    ----------
    work_dir : str
        Directory used to re-root the paths stored in the SPM job.
    step : object
        One step of a loaded SPM job structure; ``step.spatial.preproc``
        and ``step.temporal.st`` are read when present.

    Returns
    -------
    dict
        ``'anatomy'``/``'wmanatomy'`` when spatial preprocessing is present,
        plus ``'n_slices'``, ``'ref_slice'``, ``'slice_order'``, ``'ta'``,
        ``'tr'``, ``'bold'``, ``'swabold'`` and ``'n_scans'`` when
        slice-timing is present.
    """
    doc = {}

    has_spatial = hasattr(step, 'spatial')
    if has_spatial and hasattr(step.spatial, 'preproc'):
        doc['anatomy'] = makeup_path(
            work_dir, check_path(step.spatial.preproc.data))
        doc['wmanatomy'] = prefix_filename(doc['anatomy'], 'wm')

    if hasattr(step, 'temporal'):
        slice_timing = step.temporal.st
        doc.update(
            n_slices=int(slice_timing.nslices),
            ref_slice=int(slice_timing.refslice),
            slice_order=slice_timing.so.tolist(),
            ta=float(slice_timing.ta),
            tr=float(slice_timing.tr),
            bold=[],
            swabold=[],
        )

        all_scans = slice_timing.scans
        # 0-d first element: the array is one session; else a list of them.
        single_session = len(all_scans[0].shape) == 0
        for session in ([all_scans] if single_session else all_scans):
            data_dir = find_data_dir(work_dir, str(session[0]))
            doc['bold'].append(check_paths(
                [os.path.join(data_dir, os.path.basename(str(item)))
                 for item in session]))
            doc['swabold'].append(check_paths(
                [prefix_filename(
                    os.path.join(data_dir, os.path.basename(str(item))),
                    'swa')
                 for item in session]))
        doc['n_scans'] = list(map(len, doc['bold']))
    return doc
示例#9
0
parser.add_argument("--decay-steps",
                    type=int,
                    default=100,
                    help="Èpoques de decaïment de la taxa d'aprenentatge")
parser.add_argument("--decay-rate",
                    type=float,
                    default=0.96,
                    help="Decaïment de la taxa d'aprenentatge")
parser.add_argument("--epoch-save",
                    type=int,
                    default=100,
                    help="Nombre de èpoques per desar les sortides")
opt = parser.parse_args()

# Check that the output directory exists; if it does not, it is created.
utils.check_paths(opt.output_path)

# Get the dimensions of the content image to be transferred.
width, height = keras.preprocessing.image.load_img(opt.base_image_path).size
# Compute the width of the generated images so that they keep the content
# image's aspect ratio at a height of opt.img_nrows.
img_ncols = int(width * opt.img_nrows / height)

# Build a VGG19 model pre-trained on ImageNet (without the classifier head).
model = vgg19.VGG19(weights='imagenet', include_top=False)

# Get the symbolic output of every layer, keyed by its (unique) layer name.
outputs_dict = {layer.name: layer.output for layer in model.layers}

# Set up a model that returns the activations of every VGG19 layer as a dict.
feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)
示例#10
0
def main():
    """Build the fast-neural-style command line and run the chosen subcommand.

    ``train`` validates paths with ``check_paths`` and then calls ``train``;
    ``eval`` calls ``stylize``. Exits with status 1 when no subcommand is
    given or when CUDA is requested but unavailable.
    """
    parser = argparse.ArgumentParser(
        description="parser for fast-neural-style")
    commands = parser.add_subparsers(title="subcommands", dest="subcommand")

    # (flag, add_argument keyword options), registered in this order.
    train_options = (
        ("--epochs", dict(type=int, default=2,
                          help="number of training epochs, default is 2")),
        ("--batch-size", dict(type=int, default=4,
                              help="batch size for training, default is 4")),
        ("--dataset", dict(
            type=str, required=True,
            help="path to training dataset, the path should point to a folder "
                 "containing another folder with all the training images")),
        ("--style-image", dict(type=str,
                               default="images/style-images/mosaic.jpg",
                               help="path to style-image")),
        ("--save-model-dir", dict(
            type=str, default="ckpt",
            help="path to folder where trained model will be saved.")),
        ("--image-size", dict(
            type=int, default=256,
            help="size of training images, default is 256 X 256")),
        ("--style-size", dict(
            type=int, default=None,
            help="size of style-image, default is the original size of style image")),
        ("--cuda", dict(type=int, required=True,
                        help="set it to 1 for running on GPU, 0 for CPU")),
        ("--seed", dict(type=int, default=42,
                        help="random seed for training")),
        ("--content-weight", dict(
            type=float, default=1.0,
            help="weight for content-loss, default is 1.0")),
        ("--style-weight", dict(
            type=float, default=5.0,
            help="weight for style-loss, default is 5.0")),
        ("--lr", dict(type=float, default=1e-3,
                      help="learning rate, default is 0.001")),
        ("--log-interval", dict(
            type=int, default=500,
            help="number of images after which the training loss is logged, default is 500")),
    )
    eval_options = (
        ("--content-image", dict(
            type=str, required=True,
            help="path to content image you want to stylize")),
        ("--content-scale", dict(
            type=float, default=None,
            help="factor for scaling down the content image")),
        ("--output-image", dict(type=str, required=True,
                                help="path for saving the output image")),
        ("--model", dict(
            type=str, required=True,
            help="saved model to be used for stylizing the image")),
        ("--cuda", dict(type=int, required=True,
                        help="set it to 1 for running on GPU, 0 for CPU")),
    )

    train_parser = commands.add_parser("train",
                                       help="parser for training arguments")
    for flag, options in train_options:
        train_parser.add_argument(flag, **options)

    eval_parser = commands.add_parser(
        "eval", help="parser for evaluation/stylizing arguments")
    for flag, options in eval_options:
        eval_parser.add_argument(flag, **options)

    args = parser.parse_args()

    if args.subcommand is None:
        print("ERROR: specify either train or eval")
        sys.exit(1)
    if args.cuda and not torch.cuda.is_available():
        print("ERROR: cuda is not available, try running on CPU")
        sys.exit(1)

    if args.subcommand != "train":
        stylize(args)
        return
    check_paths(args)
    train(args)
示例#11
0
def main():
    """Build the train/eval command line for the network and dispatch it.

    ``train`` seeds torch with ``--seed``, validates paths with
    ``check_paths``, prints the parsed arguments and calls ``train``.
    ``eval`` prints the parsed arguments and calls ``test``. Exits with
    status 1 when no subcommand is given or when CUDA is requested but
    unavailable.
    """
    main_arg_parser = argparse.ArgumentParser(
        description="parser for neural network")
    subparsers = main_arg_parser.add_subparsers(title="subcommands",
                                                dest="subcommand")

    train_arg_parser = subparsers.add_parser(
        "train", help="parser for training arguments")
    train_arg_parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="number of training epochs, default is 10")
    train_arg_parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="batch size for training, default is 32")
    train_arg_parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="path to training dataset, the path should point to a folder "
        "containing another folder with all the training images")
    train_arg_parser.add_argument(
        "--save-model-dir",
        type=str,
        required=True,
        help="path to folder where trained model will be saved.")
    train_arg_parser.add_argument(
        "--checkpoint-model-dir",
        type=str,
        default=None,
        help="path to folder where checkpoints of trained models will be saved"
    )
    train_arg_parser.add_argument(
        "--image-size",
        type=int,
        default=100,
        help="size of training images, default is 100 X 100")
    train_arg_parser.add_argument(
        "--val-rate",
        type=float,
        default=0.2,
        help="the rate of training data used as validation set")
    train_arg_parser.add_argument('--cuda',
                                  action='store_true',
                                  help='enables cuda')
    train_arg_parser.add_argument("--seed",
                                  type=int,
                                  default=42,
                                  help="random seed for training")
    train_arg_parser.add_argument("--lr",
                                  type=float,
                                  default=1e-3,
                                  help="learning rate, default is 1e-3")

    eval_arg_parser = subparsers.add_parser(
        "eval", help="parser for evaluation arguments")
    eval_arg_parser.add_argument("--dataset",
                                 type=str,
                                 required=True,
                                 help="path for test dataset")
    # Help texts below describe evaluation (they previously said "training").
    eval_arg_parser.add_argument("--batch-size",
                                 type=int,
                                 default=64,
                                 help="batch size for evaluation, default is 64")
    eval_arg_parser.add_argument(
        "--image-size",
        type=int,
        default=100,
        help="size of evaluation images, default is 100 X 100")
    eval_arg_parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="path to trained model, should be an exact path.")
    eval_arg_parser.add_argument('--cuda',
                                 action='store_true',
                                 help='enables cuda')

    args = main_arg_parser.parse_args()

    if args.subcommand is None:
        print("ERROR: specify either train or eval")
        sys.exit(1)
    if args.cuda and not torch.cuda.is_available():
        print("ERROR: cuda is not available, try running on CPU")
        sys.exit(1)

    if args.subcommand == "train":
        torch.manual_seed(args.seed)
        check_paths(args)
        print(args)
        train(args)
    else:
        print(args)
        test(args)
示例#12
0
def main():
    """Build the style-transfer command line and run the ``eval`` subcommand.

    Only an ``eval`` subparser is defined; ``eval`` validates paths with
    ``check_paths`` and calls ``stylize``. Exits with status 1 when no
    subcommand is given or when CUDA is requested but unavailable.
    """
    # add more configurations as we go, this is just a sample
    main_arg_parser = argparse.ArgumentParser(
        description="parser for fast-neural-style")
    subparsers = main_arg_parser.add_subparsers(title="subcommands",
                                                dest="subcommand")

    eval_arg_parser = subparsers.add_parser(
        "eval", help="parser for evaluation/stylizing arguments")
    eval_arg_parser.add_argument("--num-steps",
                                 type=int,
                                 default=500,
                                 help="num-steps")
    eval_arg_parser.add_argument("--lr",
                                 type=float,
                                 default=1,
                                 help="choose learning rate")
    eval_arg_parser.add_argument(
        "--content-image",
        type=str,
        required=True,
        help="path to content image you want to stylize")
    eval_arg_parser.add_argument(
        "--style-image",
        type=str,
        required=True,
        help="path to style image you want to stylize")
    eval_arg_parser.add_argument("--image-size",
                                 type=int,
                                 default=512,
                                 help="image size")

    eval_arg_parser.add_argument("--style-weight",
                                 type=float,
                                 default=10000,
                                 help="style_weight")
    eval_arg_parser.add_argument("--content-weight",
                                 type=float,
                                 default=0.01,
                                 help="content_weight")

    eval_arg_parser.add_argument("--output-image",
                                 type=str,
                                 required=True,
                                 help="path for saving the output image")
    eval_arg_parser.add_argument('--cuda',
                                 action='store_true',
                                 help='enables cuda')

    eval_arg_parser.add_argument("--optimizer",
                                 type=str,
                                 default="L-BFGS",
                                 help="choose optimizer")
    # regularization term
    eval_arg_parser.add_argument("--reg",
                                 type=str,
                                 default=None,
                                 help="choose regularizer")
    eval_arg_parser.add_argument("--reg-weight",
                                 type=float,
                                 default=0.0,
                                 help="choose regularization strength")

    # laplacian term
    eval_arg_parser.add_argument("--lap-weight",
                                 type=float,
                                 default=None,
                                 help="choose laplacian weights")

    # color preserving
    eval_arg_parser.add_argument('--color-prev',
                                 action='store_true',
                                 help='enables color preserving')

    args = main_arg_parser.parse_args()

    print(args)

    if args.subcommand is None:
        print("ERROR: specify either train or eval")
        sys.exit(1)
    if args.cuda and not torch.cuda.is_available():
        print("ERROR: cuda is not available, try running on CPU")
        sys.exit(1)

    # NOTE(review): this overrides any --image-size the user passed whenever
    # --cuda is set; confirm that forcing 512 on GPU is intended.
    if args.cuda:
        args.image_size = 512

    if args.subcommand == "train":
        # No "train" subparser is registered, so this branch is a placeholder.
        pass
    else:
        check_paths(args)
        stylize(args)
示例#13
0
def train_proxnet(args):
    """Train a PGD-Net (``ProxNet``) on the MRF training set and save it.

    Parameters
    ----------
    args : namespace
        The function reads these attributes:

        - setup: ``cuda``, ``seed``, ``sampling``, ``batch_size``.
        - optimisation: ``lr``, ``weight_decay``, ``lr2``, ``epochs``.
        - loss: ``noise_sigam``, ``time_step``, and ``loss_weight``. The
          latter is indexed as ``['m'][0..2]``, ``['x']`` and ``['y']``.
        - output: ``save_model_dir``, ``filename``,
          ``checkpoint_model_dir`` and ``checkpoint_interval``.

        ``args.dtype`` is set from ``set_gpu(args.cuda)``.

    Side effects
    ------------
    - Records per-epoch losses and ``alpha`` through ``LOG``/``logT``.
    - When ``args.checkpoint_model_dir`` is set, writes
      ``ckp_epoch_<epoch>.pt`` there every ``args.checkpoint_interval``
      epochs. ``<epoch>`` is 1-based and matches the stored ``'epoch'``.
    - Saves the final model to ``<save_model_dir>/<log.filename>.pt``.
    """
    check_paths(args)
    # init GPU configuration
    args.dtype = set_gpu(args.cuda)

    # init seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # define training data
    train_dataset = data.MRFData(mod='train', sampling=args.sampling)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    # init operators (subsampling + subspace dimension reduction + Fourier transformation)
    operator = OperatorBatch(sampling=args.sampling.upper()).cuda()
    H, HT = operator.forward, operator.adjoint
    bloch = BLOCH().cuda()

    # init PGD-Net (proxnet)
    proxnet = ProxNet(args).cuda()

    # init optimizer: separate learning rates for the network weights and
    # for the step-size parameter alpha.
    optimizer = torch.optim.Adam([{
        'params': proxnet.transformnet.parameters(),
        'lr': args.lr,
        'weight_decay': args.weight_decay
    }, {
        'params': proxnet.alpha,
        'lr': args.lr2
    }])

    # Multiply every learning rate by 0.1 once 20 epochs have elapsed.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[20],
                                                     gamma=0.1)

    # init loss
    mse_loss = torch.nn.MSELoss()

    # init meters
    log = LOG(args.save_model_dir,
              filename=args.filename,
              field_name=[
                  'iter', 'loss_m', 'loss_x', 'loss_y', 'loss_total', 'alpha'
              ])

    loss_epoch = 0
    loss_m_epoch, loss_x_epoch, loss_y_epoch = 0, 0, 0

    # start PGD-Net training
    print('start training...')
    for e in range(args.epochs):
        epoch = e + 1  # 1-based epoch number used for logging and files
        proxnet.train()
        loss_m_seq = []
        loss_x_seq = []
        loss_y_seq = []
        loss_total_seq = []

        for x, m, y in train_loader:
            # convert data type (cuda)
            x, m, y = x.type(args.dtype), m.type(args.dtype), y.type(
                args.dtype)
            # add Gaussian noise (std args.noise_sigam) to the measurements
            noise = args.noise_sigam * torch.randn(y.shape).type(args.dtype)
            HTy = HT(y + noise).type(args.dtype)

            # PGD-Net computation (iteration): sequences of reconstructions
            # of the MRF image x and of its tissue property map m
            m_seq, x_seq = proxnet(HTy, H, HT, bloch)

            # measurement loss averaged over the first args.time_step iterates
            loss_y = 0
            for t in range(args.time_step):
                loss_y += mse_loss(H(x_seq[t]), y) / args.time_step
            # per-channel weighted loss on the final tissue-map estimate
            loss_m = 0
            for i in range(3):
                loss_m += args.loss_weight['m'][i] * mse_loss(
                    m_seq[-1][:, i, :, :], m[:, i, :, :])
            loss_x = mse_loss(x_seq[-1], x)

            # compute loss
            loss_total = (loss_m + args.loss_weight['x'] * loss_x +
                          args.loss_weight['y'] * loss_y)

            # update gradient
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            # update meters
            loss_m_seq.append(loss_m.item())
            loss_x_seq.append(loss_x.item())
            loss_y_seq.append(loss_y.item())
            loss_total_seq.append(loss_total.item())

        # (scheduled) update learning rate
        scheduler.step()

        # epoch averages
        loss_m_epoch = np.mean(loss_m_seq)
        loss_x_epoch = np.mean(loss_x_seq)
        loss_y_epoch = np.mean(loss_y_seq)
        loss_epoch = np.mean(loss_total_seq)

        # alpha does not change until the next optimizer step, so read it once
        alpha = proxnet.alpha.detach().cpu().numpy()
        log.record(epoch, loss_m_epoch, loss_x_epoch, loss_y_epoch,
                   loss_epoch, alpha)
        logT(
            "==>Epoch {}\tloss_m: {:.6f}\tloss_x: {:.6f}\tloss_y: {:.6f}\tloss_total: {:.6f}\talpha: {}"
            .format(epoch, loss_m_epoch, loss_x_epoch, loss_y_epoch,
                    loss_epoch, alpha))

        # save checkpoint; the file name uses the same 1-based epoch as the
        # stored 'epoch' field
        if (args.checkpoint_model_dir is not None
                and epoch % args.checkpoint_interval == 0):
            proxnet.eval()
            ckpt = {
                'epoch': epoch,
                'loss_m': loss_m_epoch,
                'loss_x': loss_x_epoch,
                'loss_y': loss_y_epoch,
                'total_loss': loss_epoch,
                'net_state_dict': proxnet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'alpha': alpha
            }
            torch.save(
                ckpt,
                os.path.join(args.checkpoint_model_dir,
                             'ckp_epoch_{}.pt'.format(epoch)))
            proxnet.train()

    # save model
    proxnet.eval()
    state = {
        'epoch': args.epochs,
        'loss_m': loss_m_epoch,
        'loss_x': loss_x_epoch,
        'loss_y': loss_y_epoch,
        'total_loss': loss_epoch,
        'alpha': proxnet.alpha.detach().cpu().numpy(),
        'net_state_dict': proxnet.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    save_model_path = os.path.join(args.save_model_dir, log.filename + '.pt')
    torch.save(state, save_model_path)
    print("\nDone, trained model saved at", save_model_path)