Example No. 1
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)

    config = read_config(args.conf)
    for k, v in config.items():
        print(k, v)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print('No visual output directory is provided. Checkpoint directory will be used to store the visual results')
        output_dir = checkpoint_dir

    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='train', split=args.split, split_term=args.split_term, pre_transform=normalize_transform)
    dataset_val = ComaDataset(data_dir, dtype='val', split=args.split, split_term=args.split_term, pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir, dtype='test', split=args.split, split_term=args.split_term, pre_transform=normalize_transform)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers_thread)
    val_loader = DataLoader(dataset_val, batch_size=1, shuffle=True, num_workers=workers_thread)
    test_loader = DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(), lr=lr, weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(), lr=lr, weight_decay=weight_decay, momentum=0.9)
    else:
        raise Exception('Unsupported optimizer: {}'.format(opt))

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Move optimizer state tensors to the target device.
        # TODO: check whether newer PyTorch versions handle this on load.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)

    if eval_flag:
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test, template_mesh, device, visualize)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join('runs/ae', current_time)
    writer = SummaryWriter(log_dir+'-ds2_lr0.04')

    for epoch in range(start_epoch, total_epochs + 1):
        print("Training for epoch ", epoch)
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        val_loss = evaluate(coma, output_dir, val_loader, dataset_val, template_mesh, device, epoch, visualize=visualize)

        writer.add_scalar('data/train_loss', train_loss, epoch)
        writer.add_scalar('data/val_loss', val_loss, epoch)

        print('epoch ', epoch,' Train loss ', train_loss, ' Val loss ', val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss, checkpoint_dir)
            best_val_loss = val_loss

        if epoch == total_epochs or epoch % 100 == 0:
            save_model(coma, optimizer, epoch, train_loss, val_loss, checkpoint_dir)

        val_loss_history.append(val_loss)
        val_losses.append(best_val_loss)

        if opt=='sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    writer.close()
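
Note: the loop above calls a save_model helper that the snippet does not show. Below is a minimal sketch consistent with the keys the checkpoint-loading code reads back ('epoch_num', 'state_dict', 'optimizer'); the extra loss fields and the file-naming scheme are assumptions.

import os
import torch

def save_model(model, optimizer, epoch, train_loss, val_loss, checkpoint_dir):
    # Key names mirror what the resume path expects:
    # epoch_num, state_dict, optimizer.
    checkpoint = {
        'epoch_num': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }
    torch.save(checkpoint,
               os.path.join(checkpoint_dir, 'checkpoint_{}.pt'.format(epoch)))
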
Example No. 2
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir, dtype='test', split=args.split, split_term=args.split_term, pre_transform=normalize_transform)
    data_loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1)

    print('Loading model')
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        coma.load_state_dict(checkpoint['state_dict'])
    coma.to(device)

    meshviewer = MeshViewers(shape=(3, cols))
    coma.eval()

    exit = 0
    cnt = 0
    for i, data in enumerate(data_loader):
        data = data.to(device)
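
Note: every example converts the SciPy sparse down/up-sampling and adjacency matrices with a scipy_to_torch_sparse helper. Below is a minimal sketch of the standard COO conversion; the helper's exact body in the original project may differ.

import numpy as np
import torch

def scipy_to_torch_sparse(scp_matrix):
    # COO layout exposes the row/col index arrays that torch needs.
    coo = scp_matrix.tocoo()
    indices = torch.LongTensor(np.vstack((coo.row, coo.col)))
    values = torch.FloatTensor(coo.data)
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape))
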
Example No. 3
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)

    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir

    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()

    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('Unsupported optimizer: {}'.format(opt))

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Move optimizer state tensors to the target device.
        # TODO: check whether newer PyTorch versions handle this on load.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)
    print('Loading normalization statistics')
    norm = torch.load('../processed_data/processed/sliced_norm.pt')
    normalize_transform.mean = norm['mean']
    normalize_transform.std = norm['std']

    #'0512','0901','0516','0509','0507','9305','0503','4919','4902',
    files = [
        '0514', '0503', '0507', '0509', '0512', '0501', '0901', '1001', '4902',
        '4913', '4919', '9302', '9305', '12411'
    ]

    coma.eval()

    meshviewer = MeshViewers(shape=(1, 2))
    for file in files:
        #mat = np.load('../Dress Dataset/'+file+'/'+file+'_pose.npz')
        mesh_dir = os.listdir('../processed_data/' + file + '/mesh/')
        latent = []
        print(len(mesh_dir))
        for i in tqdm(range(len(mesh_dir))):
            data_file = '../processed_data/' + file + '/mesh/' + str(
                i) + '.obj'
            mesh = Mesh(filename=data_file)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            edge_index = torch.Tensor(np.vstack(
                (adjacency.row, adjacency.col))).long()
            mesh_verts = (torch.Tensor(mesh.v) -
                          normalize_transform.mean) / normalize_transform.std
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
            data = data.to(device)
            with torch.no_grad():
                out, feature = coma(data)
                latent.append(feature.cpu().detach().numpy())
            # print(feature.shape)
            if i % 50 == 0:
                expected_out = data.x
                out = out.cpu().detach(
                ) * normalize_transform.std + normalize_transform.mean
                expected_out = expected_out.cpu().detach(
                ) * normalize_transform.std + normalize_transform.mean
                out = out.numpy()
                save_obj(out, template_mesh.f + 1,
                         './vis/reconstruct_' + str(i) + '.obj')
                save_obj(expected_out, template_mesh.f + 1,
                         './vis/ori_' + str(i) + '.obj')

        np.save('./processed/0820/' + file, latent)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
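
Note: Example 3 overwrites normalize_transform.mean and normalize_transform.std, then standardizes and de-standardizes vertices by hand. Below is a minimal sketch of what the Normalize pre-transform is assumed to look like; the __call__ signature follows the torch_geometric pre_transform convention, and the exact class body is an assumption.

class Normalize(object):
    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std

    def __call__(self, data):
        # data.x holds the mesh vertices; standardize them per coordinate.
        data.x = (data.x - self.mean) / self.std
        return data
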
Example No. 4
def main(checkpoint, config_path, output_dir):
    config = read_config(config_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Preparing dataset')
    data_dir = config['data_dir']
    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='test',
                          split='sliced',
                          split_term='sliced',
                          pre_transform=normalize_transform)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    print('Loading model')
    model = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    checkpoint = torch.load(checkpoint)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    model.eval()
    model.to(device)

    print('Generating latent')
    data = next(iter(loader))
    with torch.no_grad():
        data = data.to(device)
        x = data.x.reshape(data.num_graphs, -1, model.filters[0])
        z = model.encoder(x)

    print('View meshes')
    meshviewer = MeshViewers(shape=(1, 1))
    for feature_index in range(z.size(1)):
        # torch.range is deprecated; torch.arange with end 4.1 reproduces the
        # original inclusive sweep from -4 to 4.
        j = torch.arange(-4, 4.1, step=0.1, device=device)
        new_z = z.expand(j.size(0), z.size(1)).clone()
        new_z[:, feature_index] *= 1 + 0.3 * j

        with torch.no_grad():
            out = model.decoder(new_z)
            out = out.detach().cpu() * dataset.std + dataset.mean

        for i in trange(out.shape[0]):
            mesh = Mesh(v=out[i], f=template_mesh.f)
            meshviewer[0][0].set_dynamic_meshes([mesh])

            f = os.path.join(output_dir, 'z{}'.format(feature_index),
                             '{:04d}.png'.format(i))
            os.makedirs(os.path.dirname(f), exist_ok=True)
            meshviewer[0][0].save_snapshot(f, blocking=True)
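
Note: the sweep above writes one PNG per step into output_dir/z{k}/. A dimension's traversal is easier to judge as an animation; a small optional post-processing sketch using imageio (an addition, not part of the original code):

import glob
import os
import imageio

def frames_to_gif(output_dir, feature_index):
    # Collect the snapshots for one latent dimension in frame order.
    pattern = os.path.join(output_dir, 'z{}'.format(feature_index), '*.png')
    frames = [imageio.imread(f) for f in sorted(glob.glob(pattern))]
    imageio.mimsave(os.path.join(output_dir,
                                 'z{}.gif'.format(feature_index)), frames)
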
Example No. 5
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)

    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir

    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])
    print(len(M))

    for i in range(len(M)):
        print(M[i].v.shape)
    print('************A****************')
    for a in A:
        print(a.shape)
    print('************D****************')
    for d in D:
        print(d.shape)
    print('************U****************')
    for u in U:
        print(u.shape)

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir,
                               dtype='test',
                               split=args.split,
                               split_term=args.split_term,
                               pre_transform=normalize_transform)
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers_thread)
    test_loader = DataLoader(dataset_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('Unsupported optimizer: {}'.format(opt))

    checkpoint_file = config['checkpoint_file']

    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Move optimizer state tensors to the target device.
        # TODO: check whether newer PyTorch versions handle this on load.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)

    if eval_flag:
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test,
                            template_mesh, device, visualize)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []
    train_loss_history = []

    for epoch in range(start_epoch, total_epochs + 1):
        print("Training for epoch ", epoch)
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        val_loss = evaluate(coma,
                            output_dir,
                            test_loader,
                            dataset_test,
                            template_mesh,
                            device,
                            visualize=visualize)

        val_loss_history.append(val_loss)
        train_loss_history.append(train_loss)

        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ',
              val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss,
                       checkpoint_dir)
            best_val_loss = val_loss
            val_losses.append(best_val_loss)

        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    times = list(range(len(train_loss_history)))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(times, train_loss_history, label='train')
    ax.plot(times, val_loss_history, label='val')
    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.legend()
    plt.savefig(os.path.join(checkpoint_dir, 'result.png'))
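
Note: the SGD branch in Examples 1, 5 and 6 calls adjust_learning_rate(optimizer, lr_decay) once per epoch. Below is a minimal sketch assuming a multiplicative decay; the actual schedule in the project may differ.

def adjust_learning_rate(optimizer, lr_decay):
    # Scale every parameter group's learning rate by the decay factor.
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * lr_decay
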
Example No. 6
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)

    config = read_config(args.conf)
    print(colored(str(config), 'cyan'))

    eval_flag = config['eval']

    if not eval_flag:  # train mode: fresh or reload
        current_log_dir = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        current_log_dir = os.path.join('../Experiments/', current_log_dir)
    else:  # eval mode: save result PLYs
        if args.load_checkpoint_dir:
            current_log_dir = '../Eval'
        else:
            print(
                colored(
                    '*****please provide checkpoint file path to reload!*****',
                    'red'))
            return

    print(colored('logs will be saved in: {}'.format(current_log_dir),
                  'yellow'))

    if args.load_checkpoint_dir:
        load_checkpoint_dir = os.path.join('../Experiments/',
                                           args.load_checkpoint_dir,
                                           'chkpt')  #load last checkpoint
        print(
            colored('load_checkpoint_dir: {}'.format(load_checkpoint_dir),
                    'red'))

    save_checkpoint_dir = os.path.join(current_log_dir, 'chkpt')
    print(
        colored('save_checkpoint_dir: {}\n'.format(save_checkpoint_dir),
                'yellow'))
    if not os.path.exists(save_checkpoint_dir):
        os.makedirs(save_checkpoint_dir)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)
    print(template_file_path)

    visualize = config['visualize']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        print(colored('\n...cuda is available...\n', 'green'))
    else:
        print(colored('\n...cuda is NOT available...\n', 'red'))

    ds_factors = config['downsampling_factors']
    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, ds_factors)

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]
    print(colored('number of nodes in encoder : {}'.format(num_nodes), 'blue'))

    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    print('*** data loaded from {} ***'.format(data_dir))

    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term)
    dataset_test = ComaDataset(data_dir,
                               dtype='test',
                               split=args.split,
                               split_term=args.split_term)
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers_thread)
    test_loader = DataLoader(dataset_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=workers_thread)

    print("x :\n{} for dataset[0] element".format(dataset[0]))
    print(colored(train_loader, 'red'))
    print('Loading Model : \n')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)

    tbSummWriter = SummaryWriter(current_log_dir)

    print_model_summary = False
    if print_model_summary:
        print(coma)

    mrkdwn = '<pre><code>' + str(coma) + '</code></pre>'
    tbSummWriter.add_text('tag2', mrkdwn, global_step=None, walltime=None)

    # Write the network architecture into a text file.
    logfile = os.path.join(current_log_dir, 'coma.txt')
    with open(logfile, 'w') as f:
        f.write(str(coma))

    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('Unsupported optimizer: {}'.format(opt))

    if args.load_checkpoint_dir:
        # Pick the newest saved checkpoint by creation time.
        chkpt_list = sorted(
            os.listdir(load_checkpoint_dir),
            key=lambda f: os.path.getctime(
                os.path.join(load_checkpoint_dir, f)))
        checkpoint_file = chkpt_list[-1]

        logfile = os.path.join(current_log_dir, 'loadedfrom.txt')
        my_data_file = open(logfile, 'w')
        my_data_file.write(str(load_checkpoint_dir))
        my_data_file.close()

        print(
            colored(
                '\n\nloading Newest checkpoint : {}\n'.format(checkpoint_file),
                'red'))
        if checkpoint_file:
            checkpoint = torch.load(
                os.path.join(load_checkpoint_dir, checkpoint_file))
            start_epoch = checkpoint['epoch_num']
            coma.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Move optimizer state tensors to the target device.
            # TODO: check whether newer PyTorch versions handle this on load.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(device)
    coma.to(device)

    for i, dt in enumerate(train_loader):
        dt = dt.to(device)
        graphstr = pms.summary(coma,
                               dt,
                               batch_size=-1,
                               show_input=True,
                               show_hierarchical=False)
        if print_model_summary:
            print(graphstr)

        print(colored('dt in enumerate(train_loader):{} '.format(dt), 'green'))
        # Write the model summary into a text file.
        logfile = os.path.join(current_log_dir, 'pms.txt')
        with open(logfile, 'w') as f:
            f.write(graphstr)

        mrkdwn = str('<pre><code>' + graphstr + '</code></pre>')
        tbSummWriter.add_text('tag', mrkdwn, global_step=None, walltime=None)
        break  #for one sample only

    if eval_flag and args.load_checkpoint_dir:
        evaluatedFrom = 'predictedPlys_' + checkpoint_file
        output_dir = os.path.join('../Experiments/', args.load_checkpoint_dir,
                                  evaluatedFrom)  #load last checkpoint
        val_loss = evaluate(coma,
                            test_loader,
                            dataset_test,
                            template_mesh,
                            device,
                            visualize=True,
                            output_dir=output_dir)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []

    for epoch in range(start_epoch, total_epochs + 1):
        print("Training for epoch ", epoch)
        print('dataset.len : {}'.format(len(dataset)))

        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        val_loss = evaluate(coma,
                            test_loader,
                            dataset_test,
                            template_mesh,
                            device,
                            visualize=False,
                            output_dir='')  #train without visualization

        tbSummWriter.add_scalar('Loss/train', train_loss, epoch)
        tbSummWriter.add_scalar('Loss/val', val_loss, epoch)
        tbSummWriter.add_scalar('learning_rate', lr, epoch)

        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ',
              val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss,
                       save_checkpoint_dir)
            best_val_loss = val_loss

        val_loss_history.append(val_loss)
        val_losses.append(best_val_loss)

        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    tbSummWriter.flush()
    tbSummWriter.close()
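
Note: all of the resume paths above set start_epoch = checkpoint['epoch_num'], so the checkpointed epoch is trained a second time after a reload. If 'epoch_num' records the last completed epoch and re-running it is not intended, resuming one epoch later is a one-line change:

# Resume from the epoch after the last completed one
# (assumes 'epoch_num' records the last finished epoch).
start_epoch = checkpoint['epoch_num'] + 1
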