def main():
    """Parse CLI arguments and run the requested ``Experimenter`` experiment.

    Command-line arguments:
        --sc-root: forwarded to ``Experimenter`` as ``sc_root``.
        --experiments-root: forwarded to ``Experimenter`` as ``experiments_root``.
        --run: experiment to run; must be one of the names in ``experiments``
            below. Also forwarded to ``Experimenter`` as ``experiment_name``.
        --device: device string. When omitted, ``cuda:0`` is used if CUDA is
            available, otherwise ``cpu``. Note that ``args.device`` is a ``str``
            when supplied and a ``torch.device`` when defaulted.

    Raises:
        ValueError: if ``--run`` is not a known experiment. This is raised
            before ``common_init()`` runs and before any ``Experimenter`` is
            constructed.
    """
    # Each entry is both an accepted --run value and the name of the
    # Experimenter method that implements it.
    experiments = ('compare_ae_vae_source_label', 'vary_num_classes')

    parser = argparse.ArgumentParser()
    parser.add_argument('--sc-root', type=str, required=True)
    parser.add_argument('--experiments-root', type=str, required=True)
    parser.add_argument('--run', type=str, required=True)
    parser.add_argument('--device', type=str)
    args = parser.parse_args()

    # Fail fast: check the experiment name before doing any initialisation,
    # so a typo does not first build an Experimenter under the bogus name.
    if args.run not in experiments:
        raise ValueError('invalid experiment: {!r} (expected one of {})'.format(
            args.run, ', '.join(experiments)))

    if not args.device:
        args.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')

    common_init()

    experimenter = Experimenter(sc_root=args.sc_root,
                                experiments_root=args.experiments_root,
                                experiment_name=args.run,
                                device=args.device)

    # Safe: args.run was validated against ``experiments`` above.
    getattr(experimenter, args.run)()
示例#2
0
def main():
    """Parse CLI arguments and run ``NVAETrainer`` with them.

    ``--sc09-mixture-root`` and the resolved device are passed to the
    ``NVAETrainer`` constructor. Every other option is forwarded as a keyword
    argument of the same name to ``NVAETrainer.run``.

    When ``--device`` is omitted, ``cuda:0`` is used if CUDA is available,
    otherwise ``cpu``. ``args.device`` is therefore a ``str`` when supplied
    and a ``torch.device`` when defaulted.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--sc09-mixture-root', type=str, required=True)
    parser.add_argument('--checkpoints', type=str, required=True)
    parser.add_argument('--supervision', choices=['label', 'source'], default='label')
    parser.add_argument('--batch-size', type=int, default=100)
    parser.add_argument('--latent-size', type=int, default=128)
    parser.add_argument('--num-filters', type=int, default=128)
    # argparse only applies ``type`` to *string* defaults. A bare ``10`` would
    # leave args.beta as an int when --beta is omitted, but a float when it is
    # given. A float literal keeps the type consistent.
    parser.add_argument('--beta', type=float, default=10.0)
    parser.add_argument('--ae', action='store_true')
    parser.add_argument('--report-interval', type=int, default=200)
    parser.add_argument('--patience', type=int, default=10)
    parser.add_argument('--device', type=str)
    args = parser.parse_args()

    if not args.device:
        args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    common_init()

    nvae_trainer = NVAETrainer(args.sc09_mixture_root, args.device)
    nvae_trainer.run(checkpoints=args.checkpoints,
                     supervision=args.supervision,
                     batch_size=args.batch_size,
                     latent_size=args.latent_size,
                     num_filters=args.num_filters,
                     beta=args.beta,
                     ae=args.ae,
                     report_interval=args.report_interval,
                     patience=args.patience)
示例#3
0
    def setUp(self):
        """Run ``common_init`` and create the scalars and CUDA tensors for the test.

        ``common_init(self)`` is called first, so the attributes assigned
        afterwards (e.g. ``self.b``) override anything of the same name it sets.
        """
        common_init(self)

        self.a = 2.0
        self.b = 8.0
        self.xval = 4.0
        self.yval = 16.0
        # One-element int32 tensor on the current CUDA device, zero-filled.
        # NOTE(review): presumably an overflow flag for the kernel under test.
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        # 136.0 == a * xval + b * yval (2*4 + 8*16). It is built with
        # torch.full instead of the legacy torch.cuda.FloatTensor constructor
        # (same dtype, float32, and same device, the current CUDA device).
        self.ref = torch.full((1, ), 136.0, device="cuda", dtype=torch.float32)
    def setUp(self):
        """Prepare shared state, scalar operands and CUDA tensors for each test.

        ``common_init(self)`` runs before any attribute below is assigned, so
        these values win over anything it may set under the same names.
        ``ref`` holds 136.0, which equals ``a * xval + b * yval``.
        """
        common_init(self)

        self.a, self.b = 2.0, 8.0
        self.xval, self.yval = 4.0, 16.0

        # Zero-filled one-element int32 tensor on the current CUDA device.
        buf = torch.cuda.IntTensor(1)
        buf.zero_()
        self.overflow_buf = buf

        self.ref = torch.full(
            (1,), 136.0, dtype=torch.float32, device="cuda")
def main():
    """Parse CLI arguments and run ``NVAETester`` on the given checkpoints.

    ``--sc09-mixture-root``, ``--partition`` (default ``'testing'``),
    ``--size`` and the resolved device configure the tester.
    ``--checkpoints`` and ``--step`` are forwarded to ``NVAETester.run``.
    ``--size`` and ``--step`` are ``None`` when omitted. Without
    ``--device``, ``cuda:0`` is used if CUDA is available, otherwise ``cpu``.
    """
    parser = argparse.ArgumentParser()
    for flag, options in (
            ('--sc09-mixture-root', dict(type=str, required=True)),
            ('--checkpoints', dict(type=str, required=True)),
            ('--partition', dict(type=str, default='testing')),
            ('--size', dict(type=int)),
            ('--step', dict(type=int)),
            ('--device', dict(type=str)),
    ):
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    if not args.device:
        fallback = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        args.device = torch.device(fallback)

    common_init()

    tester = NVAETester(
        sc09_mixture_root=args.sc09_mixture_root,
        partition=args.partition,
        size=args.size,
        device=args.device,
    )
    tester.run(checkpoints=args.checkpoints, step=args.step)
示例#6
0
    def setUp(self):
        """Create the scale factor and CUDA tensors, then call ``common_init``.

        NOTE(review): ``common_init(self)`` runs *after* the attributes below
        are assigned, so anything it sets under the same names would override
        them. Confirm this ordering is intended.
        """
        self.scale = 4.0
        # One-element int32 tensor on the current CUDA device, zero-filled.
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        # Built with torch.full instead of the legacy torch.cuda.FloatTensor
        # constructor. Same value (1.0), dtype (float32) and device
        # (current CUDA device).
        self.ref = torch.full((1, ), 1.0, device="cuda", dtype=torch.float32)

        common_init(self)
示例#7
0
 def setUp(self):
     """Create a two-element float32 tensor of ones on CUDA, then run common_init."""
     # Shape spelled as a real 1-tuple: ``(2)`` is just the int 2, which
     # reads like a tuple but is not one.
     self.x = torch.ones((2,), device='cuda', dtype=torch.float32)
     common_init(self)
示例#8
0
 def setUp(self):
     """Call ``amp.init`` (enabled, ``patch_type=torch.half``), keep its handle, then run common_init."""
     handle = amp.init(patch_type=torch.half, enabled=True)
     self.handle = handle
     common_init(self)
示例#9
0
 def setUp(self):
     """Call ``amp.init(enabled=True)``, keep the returned handle, then run common_init."""
     handle = amp.init(enabled=True)
     self.handle = handle
     common_init(self)
示例#10
0
 def setUp(self):
     """Enable amp, build a (2, 8) float32 CUDA tensor of ones, then run common_init.

     ``amp.init`` is called before the tensor is created, and that order is
     preserved here.
     """
     self.handle = amp.init(enabled=True)
     ones = torch.ones((2, 8), dtype=torch.float32, device='cuda')
     self.x = ones
     common_init(self)
 def setUp(self):
     common_init(self)
     self.val = 4.0
     self.overflow_buf = torch.cuda.IntTensor(1).zero_()