Example #1
def output_point_cloud_ply(xyzs, name, output_folder):
    """Write an (N, 3) array of xyz points to an ASCII PLY file."""
    if not os.path.exists(output_folder):
        mkdir_p(output_folder)
    print('write: ' + os.path.join(output_folder, name + '.ply'))
    with open(os.path.join(output_folder, name + '.ply'), 'w') as f:
        pn = xyzs.shape[0]
        f.write('ply\n')
        f.write('format ascii 1.0\n')
        f.write('element vertex %d\n' % (pn))
        f.write('property float x\n')
        f.write('property float y\n')
        f.write('property float z\n')
        f.write('end_header\n')
        for i in range(pn):
            f.write('%f %f %f\n' % (xyzs[i][0], xyzs[i][1], xyzs[i][2]))
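A minimal usage sketch for the writer above; mkdir_p is assumed to behave like os.makedirs(..., exist_ok=True), and the points are synthetic:

import os
import numpy as np

def mkdir_p(path):
    # assumed equivalent of the repo's mkdir_p helper
    os.makedirs(path, exist_ok=True)

pts = np.random.rand(100, 3).astype(np.float32)
output_point_cloud_ply(pts, name='demo', output_folder='ply_out')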
Example #2
def test(test_loader, model, args, save_result=False, best_epoch=None):
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            if args.arch == 'masknet':
                mask_pred = model(data)
                mask_gt = data.mask.unsqueeze(1)
                loss = torch.nn.functional.binary_cross_entropy_with_logits(
                    mask_pred, mask_gt.float(), reduction='mean')
            elif args.arch == 'jointnet':
                data_displacement = model(data)
                y_pred = data_displacement + data.pos
                loss = 0.0
                for i in range(len(torch.unique(data.batch))):
                    y_gt_sample = data.y[data.batch == i, :]
                    y_gt_sample = y_gt_sample[:data.num_joint[i], :]
                    y_pred_sample = y_pred[data.batch == i, :]
                    loss += chamfer_distance_with_average(
                        y_pred_sample.unsqueeze(0), y_gt_sample.unsqueeze(0))
            loss_meter.update(loss.item())

            if save_result:
                output_folder = 'results/{:s}/best_{:d}/'.format(
                    outdir, best_epoch)
                if not os.path.exists(output_folder):
                    mkdir_p(output_folder)
                if args.arch == 'masknet':
                    mask_pred = torch.sigmoid(mask_pred)
                    for i in range(len(torch.unique(data.batch))):
                        mask_pred_sample = mask_pred[data.batch == i]
                        np.save(
                            os.path.join(
                                output_folder,
                                str(data.name[i].item()) + '_attn.npy'),
                            mask_pred_sample.data.cpu().numpy())
                else:
                    for i in range(len(torch.unique(data.batch))):
                        y_pred_sample = y_pred[data.batch == i, :]
                        output_point_cloud_ply(
                            y_pred_sample.data.cpu().numpy(),
                            name=str(data.name[i].item()),
                            output_folder=output_folder)
    return loss_meter.avg
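Example #2 calls a chamfer_distance_with_average helper that is not shown here. A plausible sketch, assuming it averages nearest-neighbour squared distances in both directions (the repository's version may differ):

import torch

def chamfer_distance_with_average(p1, p2):
    # p1: (1, N, 3) predicted points, p2: (1, M, 3) ground-truth points
    diff = p1.unsqueeze(2) - p2.unsqueeze(1)   # (1, N, M, 3)
    dist = (diff ** 2).sum(dim=-1)             # pairwise squared distances
    return dist.min(dim=2)[0].mean() + dist.min(dim=1)[0].mean()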
Example #3
def test(test_loader, model, args, save_result=False):
    global device
    model.eval()  # switch to test mode
    loss_meter = AverageMeter()
    outdir = args.checkpoint.split('/')[-1]
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            skin_pred = model(data)
            skin_gt = data.skin_label[:, 0:args.nearest_bone]
            loss_mask_batch = data.loss_mask.float()[:, 0:args.nearest_bone]
            skin_gt = skin_gt * loss_mask_batch
            skin_gt = skin_gt / (torch.sum(torch.abs(skin_gt), dim=1, keepdim=True) + 1e-8)
            vert_mask = (torch.abs(skin_gt.sum(dim=1) - 1.0) < 1e-8).float()
            loss = cross_entropy_with_probs(skin_pred, skin_gt, reduction='none')
            loss = (loss * loss_mask_batch * vert_mask.unsqueeze(1)).sum() / (loss_mask_batch * vert_mask.unsqueeze(1)).sum()
            loss_meter.update(loss.item())

            if save_result:
                output_folder = 'results/{:s}/'.format(outdir)
                if not os.path.exists(output_folder):
                    mkdir_p(output_folder)
                for i in range(len(torch.unique(data.batch))):
                    print('output result for model {:d}'.format(data.name[i].item()))
                    skin_pred_i = skin_pred[data.batch == i]
                    bone_names = get_bone_names(os.path.join(args.test_folder, "{:d}_skin.txt".format(data.name[i].item())))
                    tpl_e = np.loadtxt(os.path.join(args.test_folder, "{:d}_tpl_e.txt".format(data.name[i].item()))).T
                    loss_mask_sample = data.loss_mask.float()[data.batch == i, 0:args.nearest_bone]
                    skin_pred_i = torch.softmax(skin_pred_i, dim=1)
                    skin_pred_i = skin_pred_i * loss_mask_sample
                    skin_nn_i = data.skin_nn[data.batch == i, 0:args.nearest_bone]
                    skin_pred_asarray = np.zeros((len(skin_pred_i), len(bone_names)))
                    for v in range(len(skin_pred_i)):
                        for nn_id in range(len(skin_nn_i[v, :])):
                            skin_pred_asarray[v, skin_nn_i[v, nn_id]] = skin_pred_i[v, nn_id]
                    skin_pred_asarray = post_filter(skin_pred_asarray, tpl_e, num_ring=1)
                    skin_pred_asarray[skin_pred_asarray < np.max(skin_pred_asarray, axis=1, keepdims=True) * 0.5] = 0.0
                    skin_pred_asarray = skin_pred_asarray / (skin_pred_asarray.sum(axis=1, keepdims=True) + 1e-10)
                    with open(os.path.join(output_folder, "{:d}_bone_names.txt".format(data.name[i].item())), 'w') as fout:
                        for bone_name in bone_names:
                            fout.write("{:s} {:s}\n".format(bone_name[0], bone_name[1]))
                    np.save(os.path.join(output_folder, "{:d}_full_pred.npy".format(data.name[i].item())), skin_pred_asarray)
                    skel_filename = os.path.join(args.info_folder, "{:d}.txt".format(data.name[i].item()))
                    output_rigging(skel_filename, skin_pred_asarray, output_folder, data.name[i].item())
    return loss_meter.avg
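The cross_entropy_with_probs call above takes soft probability targets. A hedged sketch consistent with its usage here, keeping a trailing dimension so the per-vertex losses broadcast against the (N, nearest_bone) mask (assumed, not the repository's exact code):

import torch
import torch.nn.functional as F

def cross_entropy_with_probs(input, target, reduction='mean'):
    # input: (N, K) logits, target: (N, K) soft labels
    log_prob = F.log_softmax(input, dim=1)
    loss = -(target * log_prob).sum(dim=1, keepdim=True)  # (N, 1)
    return loss.mean() if reduction == 'mean' else loss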
Example #4
def test(test_loader, model, args, save_result=False, best_epoch=None):
    global device
    model.eval()  # switch to test mode
    if save_result:
        output_folder = 'results/{:s}/best_{:d}/'.format(
            args.checkpoint.split('/')[1], best_epoch)
        if not os.path.exists(output_folder):
            mkdir_p(output_folder)
    loss_meter = AverageMeter()
    for data in test_loader:
        data = data.to(device)
        with torch.no_grad():
            pre_label, label = model(data)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(
                pre_label, label.float())
            if save_result:
                connect_prob = torch.sigmoid(pre_label)
                accumulate_start_id = 0
                for i in range(len(torch.unique(data.batch))):
                    pair_idx = data.pairs[
                        accumulate_start_id:accumulate_start_id + data.num_pair[i]].long()
                    connect_prob_i = connect_prob[
                        accumulate_start_id:accumulate_start_id + data.num_pair[i]]
                    accumulate_start_id += data.num_pair[i]
                    cost_matrix = np.zeros((data.num_joint[i], data.num_joint[i]))
                    pair_idx = pair_idx.data.cpu().numpy()
                    cost_matrix[pair_idx[:, 0], pair_idx[:, 1]] = \
                        connect_prob_i.data.cpu().numpy().squeeze()
                    cost_matrix = 1 - cost_matrix
                    print('saving: {:s}'.format(
                        str(data.name[i].item()) + '_cost.npy'))
                    np.save(
                        os.path.join(output_folder,
                                     str(data.name[i].item()) + '_cost.npy'),
                        cost_matrix)
            loss_meter.update(loss.item(), n=len(torch.unique(data.batch)))
    return loss_meter.avg
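Each saved *_cost.npy holds 1 minus the predicted connection probability per joint pair. Downstream it can be turned into a skeleton tree, for instance with a minimum spanning tree; this is illustrative only and not taken from the source:

import numpy as np
from scipy.sparse.csgraph import minimum_spanning_tree

cost_matrix = np.random.rand(8, 8)   # stand-in for a saved *_cost.npy
mst = minimum_spanning_tree(cost_matrix).toarray()
parents, children = np.nonzero(mst)  # edges of the recovered tree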
Example #5
def main(args):
    global device
    lowest_loss = 1e20

    # create checkpoint dir and log dir
    if not isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
    mkdir_p(args.checkpoint)
    if not args.resume:
        if isdir(args.logdir):
            shutil.rmtree(args.logdir)
        mkdir_p(args.logdir)

    # create model

    model = PairCls()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch']
            lowest_loss = checkpoint['lowest_loss']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    train_loader = DataLoader(GraphDataset(root=args.train_folder),
                              batch_size=args.train_batch,
                              shuffle=True)
    val_loader = DataLoader(GraphDataset(root=args.val_folder),
                            batch_size=args.test_batch,
                            shuffle=False)
    test_loader = DataLoader(GraphDataset(root=args.test_folder),
                             batch_size=args.test_batch,
                             shuffle=False)

    if args.evaluate:
        print('\nEvaluation only')
        test_loss = test(test_loader,
                         model,
                         args,
                         save_result=True,
                         best_epoch=args.start_epoch)
        print('test_loss {:8f}'.format(test_loss))
        return

    lr = args.lr
    logger = SummaryWriter(log_dir=args.logdir)
    for epoch in range(args.start_epoch, args.epochs):
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        train_loss = train(train_loader, model, optimizer, args)
        val_loss = test(val_loader, model, args)
        test_loss = test(test_loader, model, args, best_epoch=epoch + 1)
        print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch, train_loss))
        print('Epoch{:d}. val_loss: {:.6f}.'.format(epoch, val_loss))
        print('Epoch{:d}. test_loss: {:.6f}.'.format(epoch, test_loss))

        # remember best acc and save checkpoint
        is_best = val_loss < lowest_loss
        lowest_loss = min(val_loss, lowest_loss)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'lowest_loss': lowest_loss,
                'optimizer': optimizer.state_dict()
            },
            is_best,
            checkpoint=args.checkpoint)

        info = {
            'train_loss': train_loss,
            'val_loss': val_loss,
            'test_loss': test_loss
        }
        for tag, value in info.items():
            logger.add_scalar(tag, value, epoch + 1)

    print("=> loading checkpoint '{}'".format(
        os.path.join(args.checkpoint, 'model_best.pth.tar')))
    checkpoint = torch.load(os.path.join(args.checkpoint,
                                         'model_best.pth.tar'))
    best_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        os.path.join(args.checkpoint, 'model_best.pth.tar'), best_epoch))
    test_loss = test(test_loader,
                     model,
                     args,
                     save_result=True,
                     best_epoch=best_epoch)
    print('Best epoch:\n test_loss {:8f}'.format(test_loss))
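Example #5 steps the learning rate through an adjust_learning_rate helper rather than a scheduler. A sketch of the assumed step-decay behaviour (the original helper may differ):

def adjust_learning_rate(optimizer, epoch, lr, schedule, gamma):
    # decay the learning rate by gamma at every epoch listed in schedule
    if epoch in schedule:
        lr *= gamma
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
    return lr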
Example #6
def one_process(dataset_folder, start_id, end_id):
    model_list = np.loadtxt(os.path.join(dataset_folder, 'model_list.txt'),
                            dtype=int)
    model_list = model_list[start_id:end_id]
    remesh_obj_folder = os.path.join(dataset_folder, "obj_remesh")
    mkdir_p(os.path.join(dataset_folder, "volumetric_geodesic/"))

    for model_id in model_list:
        print(model_id)
        geo_path = os.path.join(
            dataset_folder,
            "volumetric_geodesic/{:d}_volumetric_geo.npy".format(model_id))
        if os.path.exists(geo_path):
            continue
        remeshed_obj_filename = os.path.join(
            dataset_folder, 'obj_remesh/{:d}.obj'.format(model_id))
        ori_obj_filename = os.path.join(dataset_folder,
                                        'obj/{:d}.obj'.format(model_id))
        info_filename = os.path.join(dataset_folder,
                                     'rig_info/{:d}.txt'.format(model_id))

        pts = np.asarray(
            o3d.io.read_triangle_mesh(remeshed_obj_filename).vertices)

        mesh_remesh = trimesh.load(remeshed_obj_filename)
        mesh_ori = trimesh.load(ori_obj_filename)
        rig_info = Info(info_filename)
        bones, bone_name, _ = get_bones(rig_info)
        origins, ends, pts_bone_dist = pts2line(pts, bones)

        visibility_raw_path = os.path.join(
            dataset_folder,
            "volumetric_geodesic/{:d}_visibility_raw.npy".format(model_id))
        if os.path.exists(visibility_raw_path):
            pts_bone_visibility = np.load(visibility_raw_path)
        else:
            # pick one mesh with fewer faces to speed up
            if len(mesh_remesh.faces) < len(mesh_ori.faces):
                trimesh.repair.fix_normals(mesh_remesh)
                pts_bone_visibility = calc_pts2bone_visible_mat(
                    mesh_remesh, origins, ends)
            else:
                trimesh.repair.fix_normals(mesh_ori)
                pts_bone_visibility = calc_pts2bone_visible_mat(
                    mesh_ori, origins, ends)
            pts_bone_visibility = pts_bone_visibility.reshape(
                len(bones), len(pts)).transpose()
            #np.save(os.path.join(dataset_folder, "volumetric_geodesic/{:d}_visibility_raw.npy".format(model_id)), pts_bone_visibility)
        pts_bone_dist = pts_bone_dist.reshape(len(bones), len(pts)).transpose()

        # remove visible points which are too far
        visibility_filtered_path = os.path.join(
            dataset_folder,
            "volumetric_geodesic/{:d}_visibility_filtered.npy".format(model_id))
        if os.path.exists(visibility_filtered_path):
            pts_bone_visibility = np.load(visibility_filtered_path)
        else:
            for b in range(pts_bone_visibility.shape[1]):
                visible_pts = np.argwhere(
                    pts_bone_visibility[:, b] == 1).squeeze(1)
                if len(visible_pts) == 0:
                    continue
                threshold_b = np.percentile(pts_bone_dist[visible_pts, b], 15)
                pts_bone_visibility[pts_bone_dist[:, b] > 1.3 * threshold_b, b] = False
            #np.save(os.path.join(dataset_folder, "volumetric_geodesic/{:d}_visibility_filtered.npy".format(model_id)), pts_bone_visibility)

        mesh = o3d.io.read_triangle_mesh(remeshed_obj_filename)
        surface_geodesic = calc_surface_geodesic(mesh)

        visible_matrix = np.zeros(pts_bone_visibility.shape)
        visible_matrix[pts_bone_visibility == 1] = pts_bone_dist[pts_bone_visibility == 1]
        euc_dist = np.sqrt(
            np.sum((pts[np.newaxis, ...] - pts[:, np.newaxis, :])**2, axis=2))
        for c in range(visible_matrix.shape[1]):
            unvisible_pts = np.argwhere(pts_bone_visibility[:, c] == 0).squeeze(1)
            visible_pts = np.argwhere(pts_bone_visibility[:, c] == 1).squeeze(1)
            if len(visible_pts) == 0:
                visible_matrix[:, c] = pts_bone_dist[:, c]
                continue
            for r in unvisible_pts:
                dist1 = np.min(surface_geodesic[r, visible_pts])
                nn_visible = visible_pts[np.argmin(
                    surface_geodesic[r, visible_pts])]
                if np.isinf(dist1):
                    visible_matrix[r, c] = 8.0 + pts_bone_dist[r, c]
                else:
                    visible_matrix[r, c] = dist1 + visible_matrix[nn_visible, c]
        np.save(geo_path, visible_matrix)
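pts2line above pairs every vertex with every bone segment. A hedged sketch of its assumed behaviour, returning ray origins (the points), ray ends (closest points on each segment) and point-to-segment distances in bone-major order so the later reshape to (len(bones), len(pts)) holds; it assumes each bone is a (head, tail) pair of 3-vectors:

import numpy as np

def pts2line(pts, bones):
    origins, ends, dists = [], [], []
    for head, tail in bones:
        d = tail - head
        t = np.clip((pts - head) @ d / (d @ d + 1e-10), 0.0, 1.0)
        closest = head + t[:, None] * d   # nearest point on the segment
        origins.append(pts)
        ends.append(closest)
        dists.append(np.linalg.norm(pts - closest, axis=1))
    return np.concatenate(origins), np.concatenate(ends), np.concatenate(dists)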
Example #7
def main(args):
    global device
    best_acc = 0.0

    # create checkpoint dir and log dir
    if not isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
    mkdir_p(args.checkpoint)
    if not args.resume:
        if isdir(args.logdir):
            shutil.rmtree(args.logdir)
        mkdir_p(args.logdir)

    # create model
    model = ROOTNET()

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))
    train_loader = DataLoader(GraphDataset(root=args.train_folder),
                              batch_size=args.train_batch,
                              shuffle=True,
                              follow_batch=['joints'])
    val_loader = DataLoader(GraphDataset(root=args.val_folder),
                            batch_size=args.test_batch,
                            shuffle=False,
                            follow_batch=['joints'])
    test_loader = DataLoader(GraphDataset(root=args.test_folder),
                             batch_size=args.test_batch,
                             shuffle=False,
                             follow_batch=['joints'])
    if args.evaluate:
        print('\nEvaluation only')
        test_loss, test_acc = test(test_loader, model)
        print('test_loss {:.8f}. test_acc: {:.6f}'.format(test_loss, test_acc))
        return
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     args.schedule,
                                                     gamma=args.gamma)
    logger = SummaryWriter(log_dir=args.logdir)
    for epoch in range(args.start_epoch, args.epochs):
        lr = scheduler.get_last_lr()
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
        train_loss = train(train_loader, model, optimizer, args)
        val_loss, val_acc = test(val_loader, model)
        test_loss, test_acc = test(test_loader, model)
        scheduler.step()
        print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
        print('Epoch{:d}. val_loss: {:.6f}. val_acc: {:.6f}'.format(
            epoch + 1, val_loss, val_acc))
        print('Epoch{:d}. test_loss: {:.6f}. test_acc: {:.6f}'.format(
            epoch + 1, test_loss, test_acc))

        # remember best acc and save checkpoint
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict()
            },
            is_best,
            checkpoint=args.checkpoint)

        info = {
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_accuracy': val_acc,
            'test_loss': test_loss,
            'test_accuracy': test_acc
        }
        for tag, value in info.items():
            logger.add_scalar(tag, value, epoch + 1)
    print("=> loading checkpoint '{}'".format(
        os.path.join(args.checkpoint, 'model_best.pth.tar')))
    checkpoint = torch.load(os.path.join(args.checkpoint,
                                         'model_best.pth.tar'))
    best_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded checkpoint '{}' (epoch {})".format(
        os.path.join(args.checkpoint, 'model_best.pth.tar'), best_epoch))
    test_loss, test_acc = test(test_loader, model)
    print('Best epoch:\n test_loss {:8f} test_acc {:8f}'.format(
        test_loss, test_acc))
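Examples #5, #7 and #9 share a save_checkpoint helper. A common implementation consistent with the 'model_best.pth.tar' reload at the end of each run (assumed, not verbatim from the source):

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint='checkpoint',
                    filename='checkpoint.pth.tar'):
    filepath = os.path.join(checkpoint, filename)
    torch.save(state, filepath)  # always keep the latest state
    if is_best:
        shutil.copyfile(filepath,
                        os.path.join(checkpoint, 'model_best.pth.tar'))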
Example #8
###############################################################################
###############################################################################
import os
import numpy as np

from utils.os_utils import mkdir_p
from utils.job_utils import JobScheduler

from felpy.model.wavefront import Wavefront

indir = "/gpfs/exfel/data/user/guestt/labwork/dCache/NanoKB-Pulse/EHC/"

outdir1 = "/gpfs/exfel/data/group/spb-sfx/user/guestt/h5/NanoKB-Pulse/data/EHC/integrated/"
outdir2 = "/gpfs/exfel/data/group/spb-sfx/user/guestt/h5/NanoKB-Pulse/data/EHC/cmplx/"

mkdir_p(outdir1)
mkdir_p(outdir2)


def launch():
    """
    Launch the jobs that run in main.
    """

    cwd = os.getcwd()
    script = os.path.basename(__file__)

    js = JobScheduler(cwd + "/" + script,
                      logDir="../../logs/",
                      jobName="extractIntensity",
                      partition='exfel')  # remaining arguments truncated in the source
Example #9
def main(args):
    global device
    lowest_loss = 1e20

    # create checkpoint dir and log dir
    if not isdir(args.checkpoint):
        print("Create new checkpoint folder " + args.checkpoint)
    mkdir_p(args.checkpoint)
    if not args.resume:
        if isdir(args.logdir):
            shutil.rmtree(args.logdir)
        mkdir_p(args.logdir)

    # create model
    model = models.__dict__["skinnet"](args.nearest_bone, args.Dg, args.Lf)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    lr = args.lr
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            lowest_loss = checkpoint['lowest_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr = optimizer.param_groups[0]['lr']
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))
    train_loader = DataLoader(SkinDataset(root=args.train_folder), batch_size=args.train_batch, shuffle=True)
    val_loader = DataLoader(SkinDataset(root=args.val_folder), batch_size=args.test_batch, shuffle=False)
    test_loader = DataLoader(SkinDataset(root=args.test_folder), batch_size=args.test_batch, shuffle=False)
    if args.evaluate:
        print('\nEvaluation only')
        test_loss = test(test_loader, model, args, save_result=True)
        print('test_loss {:6f}'.format(test_loss))
        return
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=args.gamma)
    logger = SummaryWriter(log_dir=args.logdir)
    for epoch in range(args.start_epoch, args.epochs):
        lr = scheduler.get_last_lr()
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
        train_loss = train(train_loader, model, optimizer, args)
        val_loss = test(val_loader, model, args)
        test_loss = test(test_loader, model, args)
        scheduler.step()
        print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
        print('Epoch{:d}. val_loss: {:.6f}.'.format(epoch + 1, val_loss))
        print('Epoch{:d}. test_loss: {:.6f}.'.format(epoch + 1, test_loss))

        # remember best acc and save checkpoint
        is_best = val_loss < lowest_loss
        lowest_loss = min(val_loss, lowest_loss)
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict(), 'lowest_loss': lowest_loss,
                         'optimizer': optimizer.state_dict()}, is_best, checkpoint=args.checkpoint)

        info = {'train_loss': train_loss, 'val_loss': val_loss, 'test_loss': test_loss}
        for tag, value in info.items():
            logger.add_scalar(tag, value, epoch + 1)
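Finally, the AverageMeter used throughout the test loops is a standard running-average tracker. A minimal version matching the update(val, n) calls above (assumed; the repository's class may track extra fields):

class AverageMeter:
    """Tracks a running sum and average of scalar values."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count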