Example #1
import torch
import torchvision.transforms as transforms

import Datasets
import pc_transforms

# `args` is assumed to be an argparse.Namespace built elsewhere in the project.


def main():
    # Defined but never used below. Note that transforms.Normalize expects
    # image-style (C, H, W) tensors, so it would not apply cleanly to point clouds.
    co_transform = pc_transforms.Compose([
        pc_transforms.ArrayToTensor(),
        transforms.Normalize(mean=[0.5, 0.5], std=[1, 1])
    ])

    input_transforms = transforms.Compose([
        pc_transforms.ArrayToTensor(),
        #   transforms.Normalize(mean=[0.5,0.5],std=[1,1])
    ])

    target_transforms = transforms.Compose([
        pc_transforms.ArrayToTensor(),
        #  transforms.Normalize(mean=[0.5, 0.5], std=[1, 1])
    ])
    """Data Loader"""

    train_dataset, valid_dataset = Datasets.__dict__[args.dataName](
        input_root=args.data,
        target_root=None,
        split=args.split_value,
        net_name='auto_encoder',
        input_transforms=input_transforms,
        target_transforms=target_transforms)
    # Peek at one sample to confirm the dataset yields (input, target) pairs.
    sample_input, sample_target = train_dataset[1]

    # Scan the training split for the global coordinate range.
    omax = 0.0
    omin = 0.0
    for input_train, target in train_dataset:
        omax = max(omax, torch.max(input_train).item())
        omin = min(omin, torch.min(input_train).item())
    # Observed range on the author's data: roughly +0.499 to -0.499.

    for input_valid, target in valid_dataset:
        omax = max(omax, torch.max(input_valid).item())
        omin = min(omin, torch.min(input_valid).item())

    print('ZMAX:', omax)
    print('ZMIN:', omin)
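
Both examples depend on pc_transforms.ArrayToTensor, whose implementation is not shown here. The following is a minimal sketch of what such a transform typically looks like, assuming point clouds arrive as NumPy float arrays; the class name matches the usage above, but the body is an assumption, not the project's code.

import numpy as np
import torch


class ArrayToTensor(object):
    """Convert a numpy.ndarray point cloud of shape (N, C) to a FloatTensor.

    Sketch only -- the real pc_transforms.ArrayToTensor may differ.
    """

    def __call__(self, array):
        assert isinstance(array, np.ndarray)
        return torch.from_numpy(array).float()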
Example #2
import datetime
import os

import torch
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.transforms as transforms
from tensorboardX import SummaryWriter

import Datasets
import models
import pc_transforms

# Project-local helpers assumed to be importable from elsewhere in the repo:
# Visualizer, ChamferLoss, train, validation, test, save_checkpoint,
# get_n_params, and the argparse-built `args`.


def main():
    """ Save Path """
    train_writer = None
    valid_writer = None
    test_writer = None

    if args.save:
        save_path = '{},{},{}epochs,b{},lr{}'.format(args.model, args.optim,
                                                     args.epochs,
                                                     args.batch_size, args.lr)
        time_stamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        save_path = os.path.join(time_stamp, save_path)
        save_path = os.path.join(args.dataName, save_path)
        save_path = os.path.join(args.save_path, save_path)
        print('==> Will save everything to {}'.format(save_path))

        if not os.path.exists(save_path):
            os.makedirs(save_path)

        #""" Setting for TensorboardX """

        train_writer = SummaryWriter(os.path.join(save_path, 'train'))
        valid_writer = SummaryWriter(os.path.join(save_path, 'valid'))
        test_writer = SummaryWriter(os.path.join(save_path, 'test'))
    # output_writer = SummaryWriter(os.path.join(save_path, 'Output_Writer'))
    """ Transforms/ Data Augmentation Tec """
    co_transforms = pc_transforms.Compose([
        #  pc_transforms.Delete(num_points=1466)
        # pc_transforms.Jitter_PC(sigma=0.01,clip=0.05),
        # pc_transforms.Scale(low=0.9,high=1.1),
        #  pc_transforms.Shift(low=-0.01,high=0.01),
        # pc_transforms.Random_Rotate(),
        #  pc_transforms.Random_Rotate_90(),

        # pc_transforms.Rotate_90(args,axis='x',angle=-1.0),# 1.0,2,3,4
        # pc_transforms.Rotate_90(args, axis='z', angle=2.0),
        # pc_transforms.Rotate_90(args, axis='y', angle=2.0),
        # pc_transforms.Rotate_90(args, axis='shape_complete')  # TODO: essential for Angela's data set
    ])

    input_transforms = transforms.Compose([
        pc_transforms.ArrayToTensor(),
        #   transforms.Normalize(mean=[0.5,0.5],std=[1,1])
    ])

    target_transforms = transforms.Compose([
        pc_transforms.ArrayToTensor(),
        #  transforms.Normalize(mean=[0.5, 0.5], std=[1, 1])
    ])
    """-----------------------------------------------Data Loader----------------------------------------------------"""

    if args.net_name == 'auto_encoder':
        [train_dataset, valid_dataset] = Datasets.__dict__[args.dataName](
            input_root=args.data,
            target_root=None,
            split=args.split_value,
            net_name=args.net_name,
            input_transforms=input_transforms,
            target_transforms=target_transforms,
            co_transforms=co_transforms)
        [test_dataset, _] = Datasets.__dict__[args.dataName](
            input_root=args.datatest,
            target_root=None,
            split=None,
            net_name=args.net_name,
            input_transforms=input_transforms,
            target_transforms=target_transforms,
            co_transforms=co_transforms)
    else:
        # Without this branch, the loaders below would reference undefined datasets.
        raise ValueError('Unsupported net_name: {}'.format(args.net_name))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               shuffle=True,
                                               pin_memory=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=args.batch_size,
                                               num_workers=args.workers,
                                               shuffle=False,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              num_workers=args.workers,
                                              shuffle=False,
                                              pin_memory=True)
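    # A quick shape check could be added here (hypothetical, left commented out
    # so it does not consume a batch):
    #   pts, _ = next(iter(train_loader))
    #   print(pts.shape)  # expected: (batch_size, num_points, 3)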
    """----------------------------------------------Model Settings--------------------------------------------------"""

    print('Model:', args.model)

    if args.pretrained:
        network_data = torch.load(args.pretrained)

        args.model = network_data['model']
        print("==> Using Pre-trained Model '{}' saved at {} ".format(
            args.model, args.pretrained))
    else:
        network_data = None

    if args.model == 'ae_pointnet':
        model = models.__dict__[args.model](args,
                                            num_points=2048,
                                            global_feat=True,
                                            data=network_data).cuda()
    else:
        model = models.__dict__[args.model](network_data).cuda()

    # model = torch.nn.DataParallel(model.cuda(), device_ids=[0, 1])
    # TODO: to make DataParallel run, apply Nigel's fix:
    # https://github.com/pytorch/pytorch/issues/1637#issuecomment-338268158

    params = get_n_params(model)
    print('| Number of parameters [{}]...'.format(params))
    """-----------------------------------------------Optimizer Settings---------------------------------------------"""

    cudnn.benchmark = True
    print('Setting {} optimizer'.format(args.optim))

    # param_groups = [{'params': model.module.bias_parameters(), 'weight_decay': args.bias_decay},
    #                  {'params': model.module.weight_parameters(), 'weight_decay':args.weight_decay}
    #                  ]
    if args.optim == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(args.momentum, args.beta))
    else:
        # Fail loudly instead of leaving `optimizer` undefined.
        raise ValueError('Unsupported optimizer: {}'.format(args.optim))
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=args.milestones, gamma=args.gamma)
    # scheduler =  torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    """-------------------------------------------------Visualer Initialization-------------------------------------"""

    visualizer = Visualizer(args)

    args.display_id = args.display_id + 10
    args.name = 'Validation'
    vis_Valid = Visualizer(args)
    vis_Valida = []
    args.display_id = args.display_id + 10

    for i in range(1, 12):
        vis_Valida.append(Visualizer(args))
        args.display_id = args.display_id + 10
    """---------------------------------------------------Loss Setting-----------------------------------------------"""

    chamfer = ChamferLoss(args)
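    # ChamferLoss measures the distance between predicted and ground-truth
    # point sets; a reference sketch of this loss appears after this example.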

    best_loss = -1
    valid_loss = 1000

    if args.test_only:
        epoch = 0
        test_loss, _, _ = test(valid_loader, model, epoch, args, chamfer,
                               vis_Valid, vis_Valida, test_writer)
        test_writer.add_scalar('mean Loss', test_loss, epoch)

        print('Average Loss :{}'.format(test_loss))
    else:
        """------------------------------------------------Training and Validation-----------------------------------"""
        for epoch in range(args.start_epoch, args.epochs):

            # NOTE: since PyTorch 1.1 the recommended order is to call
            # scheduler.step() after the epoch's optimizer updates, not before;
            # kept here to match the original training schedule.
            scheduler.step()

            train_loss, _, _ = train(train_loader, model, optimizer, epoch,
                                     args, chamfer, visualizer, train_writer)
            train_writer.add_scalar('mean Loss', train_loss, epoch)

            valid_loss, _, _ = validation(test_loader, model, epoch, args,
                                          chamfer, vis_Valid, vis_Valida,
                                          valid_writer)
            valid_writer.add_scalar('mean Loss', valid_loss, epoch)

            if best_loss < 0:
                best_loss = valid_loss

            is_best = valid_loss < best_loss

            best_loss = min(valid_loss, best_loss)

            if args.save:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'model': args.model,
                        # TODO: if DataParallel is fixed, use model.module.state_dict()
                        'state_dict': model.state_dict(),
                        'state_dict_encoder': model.encoder.state_dict(),
                        'state_dict_decoder': model.decoder.state_dict(),
                        'best_loss': best_loss
                    },
                    is_best,
                    save_path)
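
The helpers used above (ChamferLoss, save_checkpoint, get_n_params) live elsewhere in the repository. As a reference, here is a minimal sketch of a symmetric Chamfer distance written with plain PyTorch ops; the constructor signature matches the ChamferLoss(args) call above, but the body is an assumption, not the project's implementation.

import torch
import torch.nn as nn


class ChamferLoss(nn.Module):
    """Symmetric Chamfer distance between two (B, N, 3) point clouds."""

    def __init__(self, args=None):
        super(ChamferLoss, self).__init__()

    def forward(self, pred, target):
        # Pairwise squared distances between every predicted and target point.
        dist = torch.cdist(pred, target, p=2.0) ** 2  # (B, N, M)
        # Nearest-neighbour distance in both directions, averaged.
        return dist.min(dim=2)[0].mean() + dist.min(dim=1)[0].mean()

save_checkpoint presumably follows the common PyTorch pattern of writing the latest state and copying it when it is the best so far; the file names below are assumptions.

import os
import shutil

import torch


def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'):
    # Always persist the latest epoch; keep a separate copy of the best model.
    file_path = os.path.join(save_path, filename)
    torch.save(state, file_path)
    if is_best:
        shutil.copyfile(file_path, os.path.join(save_path, 'model_best.pth.tar'))

get_n_params is most likely a parameter counter; an equivalent one-liner:

def get_n_params(model):
    # Total number of elements across all parameter tensors.
    return sum(p.numel() for p in model.parameters())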