Exemplo n.º 1
0
def prepare_custom_dataset(path, test_set):
    """Instantiate the 'custom' split of ``ComaDataset`` rooted at *path*.

    The created dataset object is discarded; the call is presumably made for
    the side effects of the dataset's construction (processing with a
    ``Normalize`` pre-transform) — confirm against ``ComaDataset``.

    Args:
        path: dataset root directory passed to ``ComaDataset``.
        test_set: forwarded unchanged as ``split_term``.
    """
    normalizer = Normalize()
    ComaDataset(
        path,
        split='custom',
        split_term=test_set,
        pre_transform=normalizer,
    )
Exemplo n.º 2
0
def prepare_gnrtdx_dataset(path):
    """Instantiate the 'gnrtdx' split of ``ComaDataset`` rooted at *path*.

    Both ``split`` and ``split_term`` are set to ``'gnrtdx'``. The instance is
    discarded; the call is presumably made for the side effects of dataset
    construction (processing with ``Normalize``) — confirm against
    ``ComaDataset``.
    """
    split_name = 'gnrtdx'
    normalizer = Normalize()
    ComaDataset(path, split=split_name, split_term=split_name,
                pre_transform=normalizer)
Exemplo n.º 3
0
def prepare_sliced_dataset(path):
    """Instantiate ``ComaDataset`` at *path* using its default split settings.

    Only ``pre_transform`` is supplied (a fresh ``Normalize``). The instance is
    discarded; presumably the point is the side effect of construction —
    confirm against ``ComaDataset``.
    """
    normalizer = Normalize()
    ComaDataset(path, pre_transform=normalizer)
Exemplo n.º 4
0
def prepare_expression_dataset(path):
    """Instantiate one 'expression' split of ``ComaDataset`` per held-out expression.

    For each expression name below, a ``ComaDataset`` is built at *path* with
    ``split='expression'`` and that name as ``split_term``, each with its own
    fresh ``Normalize`` pre-transform. The instances are discarded.
    """
    held_out_expressions = (
        'bareteeth', 'cheeks_in', 'eyebrow', 'high_smile', 'lips_back',
        'lips_up', 'mouth_down', 'mouth_extreme', 'mouth_middle',
        'mouth_open', 'mouth_side', 'mouth_up',
    )
    for expression in held_out_expressions:
        ComaDataset(path,
                    split='expression',
                    split_term=expression,
                    pre_transform=Normalize())
Exemplo n.º 5
0
def main(args):
    """Encode a fixed list of mesh sequences into CoMA latent features.

    The model is built as for training (optionally restored from
    ``config['checkpoint_file']``).

    For every sequence id in ``files``, this function:

    * reads ``../processed_data/<id>/mesh/<i>.obj`` for i = 0 .. N-1, where
      N is the number of entries in that directory;
    * normalizes the vertices with the mean/std stored in
      ``../processed_data/processed/sliced_norm.pt``;
    * runs the model and collects the returned ``feature`` tensors;
    * writes them with ``np.save`` to ``./processed/0820/<id>.npy``.

    Every 50th mesh, the de-normalized reconstruction and the input are also
    written as OBJ files under ``./vis/``.

    Args:
        args: parsed command-line namespace; reads ``conf``,
            ``checkpoint_dir``, ``data_dir``, ``split`` and ``split_term``.

    Raises:
        Exception: if ``config['optimizer']`` is neither 'adam' nor 'sgd'.
    """
    if not os.path.exists(args.conf):
        print('Config not found' + args.conf)

    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir

    # output_dir can be empty when visualization is off; os.path.exists(None)
    # raises TypeError and os.makedirs('') raises, so guard on truthiness.
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    lr = config['learning_rate']
    weight_decay = config['weight_decay']
    opt = config['optimizer']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(level_mesh.v) for level_mesh in M]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()

    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    print('Loading model')
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host (and vice versa).
        checkpoint = torch.load(checkpoint_file, map_location=device)
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Move any optimizer state tensors onto the active device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)
    print('making...')
    # Normalization statistics are applied to CPU tensors below, so load them
    # onto the CPU regardless of the device they were saved from.
    norm = torch.load('../processed_data/processed/sliced_norm.pt',
                      map_location='cpu')
    normalize_transform.mean = norm['mean']
    normalize_transform.std = norm['std']

    # Previously used subset: '0512','0901','0516','0509','0507','9305','0503','4919','4902'
    files = [
        '0514', '0503', '0507', '0509', '0512', '0501', '0901', '1001', '4902',
        '4913', '4919', '9302', '9305', '12411'
    ]

    vis_dir = './vis'
    latent_dir = './processed/0820'
    # Output directories must exist before save_obj / np.save write into them.
    os.makedirs(vis_dir, exist_ok=True)
    os.makedirs(latent_dir, exist_ok=True)

    coma.eval()

    for seq_id in files:
        mesh_root = os.path.join('../processed_data', seq_id, 'mesh')
        # Meshes are assumed to be named 0.obj .. (N-1).obj; only the count of
        # directory entries is used.
        num_meshes = len(os.listdir(mesh_root))
        latent = []
        print(num_meshes)
        for i in tqdm(range(num_meshes)):
            mesh = Mesh(filename=os.path.join(mesh_root, '{}.obj'.format(i)))
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            edge_index = torch.Tensor(np.vstack(
                (adjacency.row, adjacency.col))).long()
            mesh_verts = (torch.Tensor(mesh.v) -
                          normalize_transform.mean) / normalize_transform.std
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)
            data = data.to(device)
            with torch.no_grad():
                out, feature = coma(data)
                latent.append(feature.cpu().detach().numpy())
            if i % 50 == 0:
                # Undo the normalization on both reconstruction and input and
                # hand save_obj the same (numpy) type for each.
                # NOTE(review): file names depend only on i, so each sequence
                # overwrites the previous sequence's snapshots — confirm intended.
                reconstruction = (out.cpu().detach() * normalize_transform.std +
                                  normalize_transform.mean).numpy()
                original = (data.x.cpu().detach() * normalize_transform.std +
                            normalize_transform.mean).numpy()
                save_obj(reconstruction, template_mesh.f + 1,
                         os.path.join(vis_dir, 'reconstruct_{}.obj'.format(i)))
                save_obj(original, template_mesh.f + 1,
                         os.path.join(vis_dir, 'ori_{}.obj'.format(i)))

        np.save(os.path.join(latent_dir, seq_id), latent)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
Exemplo n.º 6
0
def main(args):
    """Train a CoMA mesh autoencoder, or only evaluate it when ``config['eval']`` is set.

    Settings come from the config at ``args.conf``. CLI values in
    ``args.checkpoint_dir`` / ``args.data_dir`` override the config, and
    ``args.split`` / ``args.split_term`` select the ``ComaDataset`` split.

    In eval mode, the model is evaluated once on the test split and the
    function returns. Otherwise, for each epoch from ``start_epoch`` to
    ``config['epoch']``, it:

    * trains, then evaluates on the validation split;
    * logs both losses to TensorBoard;
    * saves a checkpoint on every validation improvement, every 100th epoch,
      and at the final epoch.

    Raises:
        Exception: if ``config['optimizer']`` is neither 'adam' nor 'sgd'.
    """
    if not os.path.exists(args.conf):
        print('Config not found' + args.conf)

    config = read_config(args.conf)
    for k in config.keys():
        print(k, config[k])

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir

    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        print('current device: ', torch.cuda.current_device())
    else:
        print('no cuda, just cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(level_mesh.v) for level_mesh in M]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    dataset_val = ComaDataset(data_dir,
                              dtype='val',
                              split=args.split,
                              split_term=args.split_term,
                              pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir,
                               dtype='test',
                               split=args.split,
                               split_term=args.split_term,
                               pre_transform=normalize_transform)
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers_thread)
    val_loader = DataLoader(dataset_val,
                            batch_size=1,
                            shuffle=True,
                            num_workers=workers_thread)
    test_loader = DataLoader(dataset_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        # map_location lets a GPU-saved checkpoint be restored on a CPU host.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Move any optimizer state tensors onto the active device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    print(coma)
    coma.to(device)

    if eval_flag:
        # The training-loop call below shows that evaluate() takes the epoch
        # as its 7th positional argument, so visualize must go by keyword;
        # passing it positionally bound it to the epoch slot and dropped it.
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test,
                            template_mesh, device, start_epoch,
                            visualize=visualize)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join('runs/cae_dxb', current_time)
    writer = SummaryWriter(log_dir + '-ds2_lr0.04_z2')

    # try/finally so the TensorBoard writer is flushed and closed even if
    # training is interrupted by an exception.
    try:
        for epoch in range(start_epoch, total_epochs + 1):
            print("Training for epoch ", epoch)
            train_loss = train(coma, train_loader, len(dataset), optimizer,
                               device)
            val_loss = evaluate(coma,
                                output_dir,
                                val_loader,
                                dataset_val,
                                template_mesh,
                                device,
                                epoch,
                                visualize=visualize)

            writer.add_scalar('data/train_loss', train_loss, epoch)
            writer.add_scalar('data/val_loss', val_loss, epoch)

            print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ',
                  val_loss)
            if val_loss < best_val_loss:
                save_model(coma, optimizer, epoch, train_loss, val_loss,
                           checkpoint_dir)
                best_val_loss = val_loss

            # Periodic / final snapshot independent of validation improvement.
            if epoch == total_epochs or epoch % 100 == 0:
                save_model(coma, optimizer, epoch, train_loss, val_loss,
                           checkpoint_dir)

            if opt == 'sgd':
                adjust_learning_rate(optimizer, lr_decay)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
    finally:
        writer.close()

    print(coma)
Exemplo n.º 7
0
        return ims

    def get_multiple(self):
        """Unimplemented placeholder; always returns ``None``."""
        return None


if __name__ == '__main__':
    # Smoke test: build the EPIC-KITCHENS training dataset and iterate once
    # over every batch of the loader.
    from transform import Compose, Normalize, Scale, ToTensor
    from tqdm import tqdm

    D = EPIC_KITCHENS(
        '/mnt/nisho_data2/hyf/EPIC-annotations/EPIC_train_action_labels.csv',
        transform=Compose([
            Scale([224, 224]),
            ToTensor(255),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]))
    loader = DataLoader(
        dataset=D,
        batch_size=2,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )
    print(len(loader))
    # tqdm is given the loader itself (which has a length) so the bar can
    # show a total and ETA; wrapping enumerate(loader) hid that length.
    # Per the author's debug notes (not verified here), each batch holds:
    #   sample['ims']  -> (b, 3, cliplen, 224, 224)
    #   sample['vid']  -> list of video ids, e.g. ['P01_01', 'P01_01']
    #   sample['verb'] -> (b, 1)
    for sample in tqdm(loader):
        pass
    plt.show()


#=================================================================
# Keypoint annotation CSVs and image roots for the two splits.
csv_dir_train, root_dir_train = 'data/training_frames_keypoints.csv', 'data/training/'
csv_dir_test, root_dir_test = 'data/test_frames_keypoints.csv', 'data/test/'

batch_size = 8

# Preprocessing / augmentation pipeline shared by both splits.
# NOTE(review): the test split also gets RandomRotation and RandomCrop, so test
# samples are randomly augmented — confirm this is intended.
_pipeline = [
    RandomRotation(15),
    Rescale(224),
    RandomCrop(223),
    Normalize(),
    ToTensor(),
]
transform = transforms.Compose(_pipeline)

trainDataset = FacialDataset(csv_dir_train, root_dir_train, transform)
testDataset = FacialDataset(csv_dir_test, root_dir_test, transform)

# Scratch check kept from development (never executed):
# sample = trainDataset[0]
# angle = 22.5
#
# sample_x = transform(sample)
# display(sample_x['image'], sample_x['keypoints'])
#
# print(sample['keypoints'][0][1])
Exemplo n.º 9
0
                            optimizer_G, losses)


def set_logging(name):
    """Send INFO-level log records to the console and to ``log/<name>.log``.

    The ``log`` directory is created relative to the current working
    directory when it does not exist yet. Configuration goes through
    ``logging.basicConfig``, so it is a no-op if the root logger already has
    handlers.

    Args:
        name: base name of the log file (``.log`` is appended).

    Raises:
        FileExistsError: if ``log`` exists but is not a directory.
    """
    log_path = Path('.') / 'log'
    # exist_ok replaces the check-then-create sequence (and its race window).
    log_path.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(level=logging.INFO,
                        format='%(message)s',
                        handlers=[
                            logging.StreamHandler(),
                            logging.FileHandler(
                                log_path / '{}.log'.format(name)),
                        ])


if __name__ == "__main__":
    # Console + file logging first, so everything below is captured.
    set_logging(cfg.name)

    augment = Compose([
        RandomCrop(480),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(cfg.mean, cfg.std),
    ])
    ds = RaindropDataset(cfg.train_path, transform=augment)
    loader = DataLoader(ds, batch_size=10, shuffle=True, num_workers=4)

    train_args = (cfg.num_rep, cfg.lr, cfg.beta1, cfg.gamma_gan,
                  cfg.num_epoch, cfg.wd, cfg.device)
    train(cfg.name, loader, True, *train_args)
Exemplo n.º 10
0
def main(args):
    """Resolve config/CLI settings and build the CoMA train/test data loaders.

    As written, this function:

    * reads the config at ``args.conf`` (``args.checkpoint_dir`` and
      ``args.data_dir`` override the config when set);
    * loads the template mesh;
    * builds the train and test ``ComaDataset`` splits for
      ``args.split`` / ``args.split_term``;
    * wraps them in ``DataLoader``s.

    Nothing is returned.
    """
    if not os.path.exists(args.conf):
        print('Config not found' + args.conf)

    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)
    print(template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    print(os.path.exists(checkpoint_dir))
    # NOTE(review): checkpoint_dir / output_dir are intentionally(?) not
    # created here, unlike the sibling entry points — confirm before writing.

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if torch.cuda.is_available():
        print('\ncuda is available...\n')
    else:
        print('\ncuda is NOT available...\n')

    print('\n*** Loading Dataset ***\n')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    print(data_dir)
    normalize_transform = Normalize()
    # Both splits use the same pre_transform, as every other ComaDataset
    # construction in this project does. The train split was previously built
    # without it while the test split used Normalize, so the two splits could
    # disagree on whether the data is normalized.
    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir,
                               dtype='test',
                               split=args.split,
                               split_term=args.split_term,
                               pre_transform=normalize_transform)

    print('Done ......... \n')

    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers_thread)
    test_loader = DataLoader(dataset_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=workers_thread)
Exemplo n.º 11
0
def main(checkpoint, config_path, output_dir):
    """Render latent-space traversals of a trained CoMA model.

    The first sample of the 'sliced' test split is encoded. Then, for each
    latent dimension in turn:

    * that coordinate is scaled by ``1 + 0.3 * t`` for 81 evenly spaced
      offsets t from -4 to 4;
    * the batch is decoded and de-normalized with ``dataset.mean`` /
      ``dataset.std``;
    * each resulting mesh is snapshotted to
      ``<output_dir>/z<dim>/<step:04d>.png``.

    Args:
        checkpoint: path of a checkpoint containing a ``'state_dict'`` entry.
        config_path: path of the config file read by ``read_config``.
        output_dir: root directory for the rendered snapshots.
    """
    config = read_config(config_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(level_mesh.v) for level_mesh in M]

    print('Preparing dataset')
    data_dir = config['data_dir']
    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='test',
                          split='sliced',
                          split_term='sliced',
                          pre_transform=normalize_transform)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    print('Loading model')
    model = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    # map_location lets a GPU-trained checkpoint load on a CPU-only host; a
    # separate name avoids shadowing the `checkpoint` path argument.
    saved = torch.load(checkpoint, map_location=device)
    model.load_state_dict(saved['state_dict'])
    model.eval()
    model.to(device)

    print('Generating latent')
    data = next(iter(loader))
    with torch.no_grad():
        data = data.to(device)
        x = data.x.reshape(data.num_graphs, -1, model.filters[0])
        z = model.encoder(x)

    print('View meshes')
    meshviewer = MeshViewers(shape=(1, 1))
    # torch.range is deprecated and its float step makes the inclusive end
    # point fragile; linspace yields the same 81 offsets -4.0, -3.9, ..., 4.0.
    # Hoisted because it does not depend on the latent dimension.
    offsets = torch.linspace(-4, 4, 81, device=device)
    for feature_index in range(z.size(1)):
        new_z = z.expand(offsets.size(0), z.size(1)).clone()
        new_z[:, feature_index] *= 1 + 0.3 * offsets

        with torch.no_grad():
            out = model.decoder(new_z)
            out = out.detach().cpu() * dataset.std + dataset.mean

        frame_dir = os.path.join(output_dir, 'z{}'.format(feature_index))
        os.makedirs(frame_dir, exist_ok=True)
        for i in trange(out.shape[0]):
            mesh = Mesh(v=out[i], f=template_mesh.f)
            meshviewer[0][0].set_dynamic_meshes([mesh])
            snapshot_path = os.path.join(frame_dir, '{:04d}.png'.format(i))
            meshviewer[0][0].save_snapshot(snapshot_path, blocking=True)
Exemplo n.º 12
0
if __name__ == '__main__':
    # Smoke test for VOCDataset: load augmented batches and undo the
    # mean/std normalization on every image of every batch.
    import numpy as np
    import matplotlib.pyplot as plt
    from transform import Normalize, RandomScaleCrop, RandomGaussianBlur, RandomHorizontalFlip, ToTensor
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from utils import utils

    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)

    augmentations = transforms.Compose([
        RandomScaleCrop(550, 512),
        RandomHorizontalFlip(),
        Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensor(),
    ])
    trainset = VOCDataset("data/train", transform=augmentations)
    dataloader = DataLoader(trainset, batch_size=2, shuffle=True,
                            num_workers=0)

    for sample in dataloader:
        for image_t, mask_t in zip(sample["image"], sample["mask"]):
            image = image_t.numpy()
            mask = mask_t.numpy()
            # CHW -> HWC, then invert Normalize in place.
            image = image.transpose([1, 2, 0])
            image *= imagenet_std
            image += imagenet_mean
            image = image * 255
Exemplo n.º 13
0
def main(args):
    """Train a CoMA mesh autoencoder (or only evaluate it) and plot the losses.

    Settings come from the config at ``args.conf``. CLI values in
    ``args.checkpoint_dir`` / ``args.data_dir`` override the config, and
    ``args.split`` / ``args.split_term`` select the ``ComaDataset`` split.
    The shapes of the generated mesh hierarchy and of its A/D/U matrices are
    printed for inspection.

    If ``config['eval']`` is truthy, the model is evaluated once on the test
    split and the function returns. Otherwise each epoch trains, evaluates on
    the test split, and saves a checkpoint whenever that loss improves. At
    the end, a train/test loss plot is written to
    ``<checkpoint_dir>/result.png``.

    Raises:
        Exception: if ``config['optimizer']`` is neither 'adam' nor 'sgd'.
    """
    if not os.path.exists(args.conf):
        print('Config not found' + args.conf)

    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print(
            'No visual output directory is provided. Checkpoint directory will be used to store the visual results'
        )
        output_dir = checkpoint_dir

    # output_dir may be empty when visualization is off; guard so neither
    # os.path.exists(None) nor os.makedirs('') is attempted.
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])
    print(len(M))

    for level_mesh in M:
        print(level_mesh.v.shape)
    print('************A****************')
    for a in A:
        print(a.shape)
    print('************D****************')
    for d in D:
        print(d.shape)
    print('************U****************')
    for u in U:
        print(u.shape)

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(level_mesh.v) for level_mesh in M]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir,
                               dtype='test',
                               split=args.split,
                               split_term=args.split_term,
                               pre_transform=normalize_transform)
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers_thread)
    test_loader = DataLoader(dataset_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    checkpoint_file = config['checkpoint_file']

    if checkpoint_file:
        # map_location lets a GPU-saved checkpoint be restored on a CPU host.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Move any optimizer state tensors onto the active device.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)

    if eval_flag:
        val_loss = evaluate(coma, output_dir, test_loader, dataset_test,
                            template_mesh, device, visualize)
        print('val loss', val_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []
    train_loss_history = []

    for epoch in range(start_epoch, total_epochs + 1):
        print("Training for epoch ", epoch)
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        # NOTE(review): "validation" here uses the test split.
        val_loss = evaluate(coma,
                            output_dir,
                            test_loader,
                            dataset_test,
                            template_mesh,
                            device,
                            visualize=visualize)

        val_loss_history.append(val_loss)
        train_loss_history.append(train_loss)

        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ',
              val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss,
                       checkpoint_dir)
            best_val_loss = val_loss

        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    times = list(range(len(train_loss_history)))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(times, train_loss_history, label='train')
    ax.plot(times, val_loss_history, label='val')
    ax.set_xlabel("iteration")
    ax.set_ylabel(" loss")
    ax.legend()
    # os.path.join supplies the separator that string concatenation dropped
    # whenever checkpoint_dir had no trailing slash.
    fig.savefig(os.path.join(checkpoint_dir, 'result.png'))
    plt.close(fig)
Exemplo n.º 14
0
def prepare_lgtdvc_dataset(path):
    """Instantiate the 'lgtdvc' split of ``ComaDataset`` rooted at *path*.

    Both ``split`` and ``split_term`` are ``'lgtdvc'``. The instance is
    discarded; the call is presumably made for the side effects of dataset
    construction (processing with ``Normalize``) — confirm against
    ``ComaDataset``.
    """
    split_name = 'lgtdvc'
    normalizer = Normalize()
    ComaDataset(path, split=split_name, split_term=split_name,
                pre_transform=normalizer)
Exemplo n.º 15
0
def main():
    """Parse CLI options and train the text ``RecognitionModel`` with CTC loss.

    Steps:

    * Creates ``--output_dir`` and logs all arguments to
      ``<output_dir>/train.log``.
    * Builds the model, optionally loading ``--load`` weights.
    * Sets up Adam with a ``ReduceLROnPlateau`` scheduler.
    * Splits the dataset at ``--data_path`` into train/val with
      ``train_test_split`` (``random_state=42``).
    * Calls ``train``.

    On Ctrl-C the current weights are saved to
    ``<output_dir>/INTERRUPTED.pth`` and the process exits with status 0.
    """
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        help='path to the data')
    parser.add_argument('--epochs',
                        '-e',
                        dest='epochs',
                        type=int,
                        help='number of train epochs',
                        default=100)
    parser.add_argument('--batch_size',
                        '-b',
                        dest='batch_size',
                        type=int,
                        help='batch size',
                        default=128)  # 1024 was also tried
    parser.add_argument('--weight_decay',
                        '-wd',
                        dest='weight_decay',
                        type=float,
                        help='weight_decay',
                        default=5e-4)
    parser.add_argument('--lr',
                        '-lr',
                        dest='lr',
                        type=float,
                        help='lr',
                        default=1e-4)
    parser.add_argument('--lr_step',
                        '-lrs',
                        dest='lr_step',
                        type=int,
                        help='lr step',
                        default=None)
    parser.add_argument('--lr_gamma',
                        '-lrg',
                        dest='lr_gamma',
                        type=float,
                        help='lr gamma factor',
                        default=None)
    parser.add_argument('--input_wh',
                        '-wh',
                        dest='input_wh',
                        type=str,
                        help='model input size',
                        default='320x64')
    parser.add_argument('--rnn_dropout',
                        '-rdo',
                        dest='rnn_dropout',
                        type=float,
                        help='rnn dropout p',
                        default=0.1)
    parser.add_argument('--rnn_num_directions',
                        '-rnd',
                        dest='rnn_num_directions',
                        type=int,
                        help='bi',
                        default=1)
    parser.add_argument('--augs',
                        '-a',
                        dest='augs',
                        type=float,
                        help='degree of geometric augs',
                        default=0)
    parser.add_argument('--load',
                        '-l',
                        dest='load',
                        type=str,
                        help='pretrained weights',
                        default=None)
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        type=float,
                        default=0.8,
                        help='train/val split')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='/tmp/logs_rec/',
                        help='dir to save log and models')
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = RecognitionModel(rnn_dropout=args.rnn_dropout,
                           num_directions=args.rnn_num_directions)
    if args.load is not None:
        # map_location lets GPU-saved weights load on a CPU-only host.
        net.load_state_dict(torch.load(args.load, map_location=device))
    net = net.to(device)
    criterion = ctc_loss
    logger.info('Model type: {}'.format(net.__class__.__name__))

    # TODO: try other optimizers and schedulers
    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) if args.lr_step is not None else None
    # --lr_gamma defaults to None, but ReduceLROnPlateau compares factor
    # against 1.0 and fails on None; fall back to PyTorch's own default (0.1).
    plateau_factor = args.lr_gamma if args.lr_gamma is not None else 0.1
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     patience=10,
                                                     factor=plateau_factor,
                                                     verbose=True)

    # dataset
    w, h = list(map(int, args.input_wh.split('x')))
    # TODO: again, augmentations is the key for many tasks
    train_transforms = Compose([
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # val_transforms = Resize(size=(w, h))

    # TODO: don't forget to work on data cleansing
    train_dataset = RecognitionDataset(args.data_path,
                                       abc=abc,
                                       transforms=train_transforms)
    val_dataset = RecognitionDataset(args.data_path, abc=abc)
    # split dataset into train/val, don't try to do this at home ;)
    train_dataset.image_names, val_dataset.image_names, train_dataset.texts, val_dataset.texts = train_test_split(
        train_dataset.image_names,
        train_dataset.texts,
        test_size=1 - args.val_split,
        random_state=42)

    # TODO: maybe implement batch_sampler for tackling imbalance, which is obviously huge in many respects
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=8,
                                  collate_fn=train_dataset.collate_fn)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=8,
                                collate_fn=val_dataset.collate_fn)
    logger.info('Length of train/val=%d/%d', len(train_dataset),
                len(val_dataset))
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              criterion,
              optimizer,
              scheduler,
              train_dataloader,
              val_dataloader,
              args=args,
              logger=logger,
              device=device)
    except KeyboardInterrupt:
        torch.save(net.state_dict(),
                   os.path.join(args.output_dir, 'INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
Exemplo n.º 16
0
    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='test',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    data_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=1)

    print('Loading model')
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
Exemplo n.º 17
0
def _select_device(device_idx, config):
    """Return the ``torch.device`` the model and its tensors should use.

    Args:
        device_idx: CUDA device index from the command line. ``None`` falls
            back to ``config['device_idx']``; a negative value forces the CPU.
        config: Parsed configuration mapping.

    Returns:
        ``cuda:<idx>`` when CUDA is available (and not forced off), else CPU.
    """
    if device_idx is None:
        # The config value is used as-is (it is not checked for negativity).
        device_idx = config['device_idx']
    elif device_idx < 0:
        return torch.device("cpu")
    if torch.cuda.is_available():
        return torch.device("cuda:" + str(device_idx))
    return torch.device("cpu")


def _load_or_build_transform_matrices(template_mesh, downsampling_factors,
                                      ds_fname):
    """Return the mesh hierarchy ``(M, A, D, U)``, cached as a pickle.

    If ``ds_fname`` does not exist, the hierarchy is generated from
    ``template_mesh`` and written to ``ds_fname``. The cache directory is
    created if needed. Otherwise the cached hierarchy is loaded and the
    meshes are rebuilt from their stored vertices and faces.
    """
    if not os.path.exists(ds_fname):
        print("Generating Transform Matrices ..")
        M, A, D, U = mesh_operations.generate_transform_matrices(
            template_mesh, downsampling_factors)
        cache_dir = os.path.dirname(ds_fname)
        if cache_dir:
            # open(..., 'wb') fails if the cache directory is missing.
            os.makedirs(cache_dir, exist_ok=True)
        M_verts_faces = [(mesh.v, mesh.f) for mesh in M]
        with open(ds_fname, 'wb') as fp:
            pickle.dump(
                {
                    'M_verts_faces': M_verts_faces,
                    'A': A,
                    'D': D,
                    'U': U
                }, fp)
        return M, A, D, U

    print("Loading Transform Matrices ..")
    # The cache is produced by the branch above. pickle.load executes
    # arbitrary code from its input, so never point ds_fname at an
    # untrusted file.
    with open(ds_fname, 'rb') as fp:
        downsampling_matrices = pickle.load(fp)

    M = [
        Mesh(v=verts, f=faces)
        for verts, faces in downsampling_matrices['M_verts_faces']
    ]
    return (M, downsampling_matrices['A'], downsampling_matrices['D'],
            downsampling_matrices['U'])


def main(args):
    """Train (or only evaluate) a mesh autoencoder configured by ``args``.

    Settings are read from the config file at ``args.conf``. Command-line
    values that are set override the matching config entries: checkpoint
    and data dirs, visualize, learning rate, workers, batch size and device.

    When ``args.train`` is false, or the config's ``eval`` flag is set, the
    model is evaluated on the test split and the function returns. Otherwise
    the model is trained from ``start_epoch`` up to the configured number of
    epochs. ``save_model`` is called whenever the validation loss improves.

    Raises:
        FileNotFoundError: if the config file ``args.conf`` does not exist.
    """
    if not os.path.exists(args.conf):
        # Previously this only printed and then carried on with a missing
        # config; fail fast with a clear message instead.
        raise FileNotFoundError('Config not found: ' + args.conf)

    config = read_config(args.conf)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.set_num_threads(args.num_threads)
    if args.rep_cudnn:
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    checkpoint_dir = (args.checkpoint_dir
                      if args.checkpoint_dir else config['checkpoint_dir'])
    checkpoint_dir = os.path.join(checkpoint_dir, args.modelname)
    print(datetime.datetime.now())
    print('checkpoint_dir', checkpoint_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    data_dir = args.data_dir if args.data_dir else config['data_dir']

    visualize = config[
        'visualize'] if args.visualize is None else args.visualize
    output_dir = config['visual_output_dir']
    if output_dir:
        output_dir = os.path.join(output_dir, args.modelname)
    if visualize is True and not output_dir:
        print('No visual output directory is provided. '
              'Checkpoint directory will be used to store the visual results')
        output_dir = checkpoint_dir

    # output_dir may still be empty/None when visualization is off:
    # os.path.exists(None) raises TypeError and os.makedirs('') fails.
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = True if not args.train else config['eval']

    if args.learning_rate:
        config['learning_rate'] = args.learning_rate
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config[
        'workers_thread'] if args.num_workers is None else args.num_workers
    opt = config['optimizer']
    batch_size = config['batch_size'] if args.batch is None else args.batch
    val_losses = []

    device = _select_device(args.device_idx, config)

    print(config)

    # normpath strips a trailing separator, so 'data/coma/' and 'data/coma'
    # map to the same cache name. The old split('/')[-1] yielded '' for the
    # trailing-slash form.
    dataset_name = os.path.basename(os.path.normpath(data_dir))
    ds_fname = os.path.join('./template/',
                            dataset_name + '_' + args.hier_matrices + '.pkl')
    M, A, D, U = _load_or_build_transform_matrices(
        template_mesh, config['downsampling_factors'], ds_fname)

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(mesh.v) for mesh in M]

    # Vertices of every hierarchy level, standardized with the per-axis
    # mean/std of the finest level (M[0]) and scaled by 0.1.
    ref_mean = np.mean(M[0].v, axis=0)
    ref_std = np.std(M[0].v, axis=0)
    nV_ref = [0.1 * (mesh.v - ref_mean) / ref_std for mesh in M]
    tV_ref = [torch.from_numpy(s).float().to(device) for s in nV_ref]

    print('Loading Dataset')

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    dataset_val = ComaDataset(data_dir,
                              dtype='val',
                              split=args.split,
                              split_term=args.split_term,
                              pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir,
                               dtype='test',
                               split=args.split,
                               split_term=args.split_term,
                               pre_transform=normalize_transform)

    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers_thread)
    val_loader = DataLoader(dataset_val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=workers_thread)
    test_loader = DataLoader(dataset_test,
                             batch_size=1 if visualize else batch_size,
                             shuffle=False,
                             num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    # SECURITY NOTE(review): eval() executes args.modelname as Python code.
    # It is meant to be a model class name visible in this module (e.g.
    # 'ComaAtt'). Never feed it untrusted command-line input.
    model_cls = eval(args.modelname)
    if args.modelname in {'ComaAtt'}:
        # The attention variant also takes the normalized reference vertices.
        gcn_model = model_cls(dataset, config, D_t, U_t, A_t, num_nodes,
                              tV_ref)
    else:
        gcn_model = model_cls(dataset, config, D_t, U_t, A_t, num_nodes)
    gcn_params = gcn_model.parameters()

    params = sum(p.numel() for p in gcn_model.parameters() if p.requires_grad)
    print("Total number of parameters is: {}".format(params))
    print(gcn_model)

    if opt == 'adam':
        optimizer = torch.optim.Adam(gcn_params,
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(gcn_params,
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('No optimizer provided')

    if args.checkpoint_file:
        checkpoint_file = os.path.join(checkpoint_dir,
                                       str(args.checkpoint_file) + '.pt')
    else:
        checkpoint_file = config['checkpoint_file']
    if eval_flag and not checkpoint_file:
        checkpoint_file = os.path.join(checkpoint_dir, 'checkpoint.pt')

    print(checkpoint_file)
    if checkpoint_file:
        print('Loading checkpoint file: {}.'.format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file, map_location=device)
        # NOTE(review): training resumes *at* the stored epoch number. If
        # save_model stores the last completed epoch, that epoch is re-run;
        # confirm against save_model before changing this.
        start_epoch = checkpoint['epoch_num']
        gcn_model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Optimizer state tensors (e.g. momentum buffers) must live on the
        # same device as the parameters.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    gcn_model.to(device)

    if eval_flag:
        val_loss, euclidean_loss = evaluate(gcn_model, output_dir, test_loader,
                                            dataset_test, template_mesh,
                                            device, visualize)
        print('val loss', val_loss)
        print('euclidean error is {} mm'.format(1000 * euclidean_loss))
        return

    best_val_loss = float('inf')
    val_loss_history = []

    for epoch in range(start_epoch, total_epochs + 1):
        print("Training for epoch ", epoch)
        train_loss = train(gcn_model, train_loader, len(dataset), optimizer,
                           device)
        val_loss, _ = evaluate(gcn_model,
                               output_dir,
                               val_loader,
                               dataset_val,
                               template_mesh,
                               device,
                               visualize=visualize)

        print('epoch {}, Train loss {:.8f}, Val loss {:.8f}'.format(
            epoch, train_loss, val_loss))
        if val_loss < best_val_loss:
            save_model(gcn_model, optimizer, epoch, train_loss, val_loss,
                       checkpoint_dir)
            best_val_loss = val_loss

        val_loss_history.append(val_loss)
        val_losses.append(best_val_loss)

        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

        # Evaluate on the test split at explicitly requested epochs. In the
        # last quarter of training, also evaluate whenever this epoch
        # matched the best validation loss so far.
        if epoch in args.epochs_eval or (val_loss <= best_val_loss and
                                         epoch > int(total_epochs * 3 / 4)):
            test_loss, euclidean_loss = evaluate(gcn_model, output_dir,
                                                 test_loader, dataset_test,
                                                 template_mesh, device,
                                                 visualize)
            # The value printed as "val loss" is the test-split loss; the
            # label is kept unchanged so existing log parsing still works.
            print('epoch {} with val loss {}'.format(epoch, test_loss))
            print('euclidean error is {} mm'.format(1000 * euclidean_loss))

    if torch.cuda.is_available():
        torch.cuda.synchronize()