コード例 #1
0
 def __init__(self):
     """Set up the TorchScript model and the image preprocessing pipeline.

     Side effects: limits torch to a single intra-op thread and enables
     flushing of denormal floats to zero; both are process-wide settings.
     """
     # Process-wide torch runtime configuration.
     torch.set_num_threads(1)
     torch.set_flush_denormal(True)

     # Load the scripted weights and switch the module to inference mode.
     scripted_model = torch.jit.load(self.MODEL_WEIGHT_PATH)
     self.model = scripted_model
     self.model.eval()

     # Per-channel normalisation statistics (the commonly used ImageNet values).
     channel_mean = [0.485, 0.456, 0.406]
     channel_std = [0.229, 0.224, 0.225]
     pipeline_steps = [
         Rescale(self.IMAGE_SIZE),
         # Reorder channels from OpenCV's BGR layout to RGB.
         partial(cv2.cvtColor, code=cv2.COLOR_BGR2RGB),
         transforms.ToTensor(),
         transforms.Normalize(mean=channel_mean, std=channel_std),
     ]
     self.preprocess = transforms.Compose(pipeline_steps)
コード例 #2
0
    def RenderVideo(self, renderVideoData):
        """Run the render pipeline for the job described by ``renderVideoData``.

        Depending on the flags on ``renderVideoData`` this either fills missing
        original frames and returns early, or extracts frames, runs frame
        interpolation, assembles the output video and finally calls the
        ``uploadBar`` callback (if one is set) with ``1``.

        When ``use_half`` is set, the process-wide default tensor type is
        switched to ``torch.HalfTensor`` for the rest of this call and restored
        before returning (also on error). Previously the switch leaked, so a
        later call with ``use_half`` disabled still ran with half defaults.
        """
        self.myRenderData = renderVideoData

        torch.set_flush_denormal(True)

        setting.AddCounter("RenderVideo")

        self.didSplitterPrinted = False

        if self.myRenderData.mute_ffmpeg == 1:
            # Extra ffmpeg arguments that silence its console output.
            self.panic = "-hide_banner -loglevel panic"

        self.SetFolders(self.myRenderData)

        torch.cuda.set_device(self.myRenderData.sel_process)

        self.myRenderData.optimizer = 0

        print("Use Half is: " + str(renderVideoData.use_half))

        torch.backends.cudnn.benchmark = renderVideoData.useBenchmark == 1

        if self.myRenderData.fillMissingOriginal == 1:
            self.FillMissingOriginalFrames()
            return

        if self.myRenderData.doOriginal:
            self.StepExtractFrames(self.myRenderData)

        # Remember the current default tensor type so the half-precision switch
        # below does not outlive this call.
        previous_tensor_type = None
        if self.myRenderData.use_half:
            previous_tensor_type = torch.empty(0).type()
            torch.set_default_tensor_type(torch.HalfTensor)

        try:
            with torch.cuda.amp.autocast(bool(self.myRenderData.use_half)):
                if self.myRenderData.doIntepolation:
                    self.model = Configure(self, self.myRenderData)
                    self.StepRenderInterpolation(self.myRenderData)

            if self.myRenderData.doVideo:
                self.StepCreateVideo(self.myRenderData)

            if self.myRenderData.uploadBar is not None:
                self.myRenderData.uploadBar(1)
        finally:
            if previous_tensor_type is not None:
                torch.set_default_tensor_type(previous_tensor_type)
コード例 #3
0
def env_config(GPUTrue, deviceName, seed=2):
    """Configure global torch/NumPy settings and select the compute device.

    Args:
        GPUTrue: if truthy, try to use ``deviceName``; otherwise use the CPU.
        deviceName: torch device string (e.g. ``"cuda:0"``) used when
            ``GPUTrue`` is truthy and CUDA is available.
        seed: seed for the torch (CPU and CUDA) and NumPy global RNGs.
            Defaults to 2, the value previously hard-coded here.

    Returns:
        ``(device, seed)``: the selected ``torch.device`` and the seed used.
    """
    # Make sure autograd anomaly detection (a slow debugging aid) is off.
    torch.autograd.set_detect_anomaly(False)

    # Shrink very small (denormal) values to zero for computational speedup.
    torch.set_flush_denormal(True)

    # Seed random number generators for reproducibility of results.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    # Use the requested GPU only when asked for and CUDA is available.
    if GPUTrue and torch.cuda.is_available():
        device = torch.device(deviceName)
    else:
        device = torch.device("cpu")

    return device, seed
コード例 #4
0
            maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
        )

        return [make_result(batch.srcs[i], batch.words[i], t) for i, t in enumerate(translations)]

    if args.buffer_size > 1:
        print('| Sentence buffer size:', args.buffer_size)
    print('| Type the input sentence and press return:')
    for inputs in buffered_read(args.buffer_size):
        indices = []
        results = []
        for batch, batch_indices, raw_batch in make_batches(inputs, args, src_dict, models[0].max_positions()):
            indices.extend(batch_indices)
            results += process_batch(batch, raw_batch)

        for i in np.argsort(indices):
            result = results[i]
            print(result.src_str)
            for hypo, align in zip(result.hypos, result.alignments):
                print(hypo)
                print(align)


if __name__ == '__main__':
    # Global torch settings for this CLI: print tensors with a precision of 4
    # and flush denormal floats to zero.
    torch.set_printoptions(precision=4)
    torch.set_flush_denormal(True)

    # Build the interactive generation parser, parse the options and run.
    parser = options.get_generation_parser(interactive=True)
    args = options.parse_args_and_arch(parser)
    main(args)
コード例 #5
0
def main_worker(gpu, args):
    """Run one worker: training, testing, JIT export, or distillation-file creation.

    Depending on ``args`` this worker exports the model to TorchScript,
    produces distillation / curriculum / visualisation outputs, or trains
    and/or evaluates the model.

    :param gpu: id of the GPU this worker runs on
    :param args: run hyper-parameters; ``args.gpu`` (and ``args.rank`` in
        distributed CUDA mode, ``args.evaluate`` after training) are updated
        in place
    :raises FileNotFoundError: if ``args.resume`` names a file that does not exist
    :raises ValueError: if ``--jit`` is requested without ``--resume``
    :raises NotImplementedError: if ``args.criterion`` is not supported
    """
    args.gpu = gpu
    utils.generate_logger(f"{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}-{gpu}.log")
    logging.info(f'args: {args}')

    # Reproducibility
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        logging.warning('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')

    if args.cuda:
        logging.info(f"Use GPU: {args.gpu} ~")
        if args.distributed:
            # Presumably args.rank is the node rank on entry; turn it into
            # a global process rank.
            args.rank = args.rank * args.gpus + gpu
            dist.init_process_group(backend='nccl', init_method=args.init_method,
                                    world_size=args.world_size, rank=args.rank)
    else:
        logging.info("Use CPU ~")

    # Create/load the model. To use pretrained weights, download them yourself
    # into the `pretrained` folder, named after the network.
    logging.info(f"=> creating model '{args.arch}'")
    model = my_models.get_model(args.arch, args.pretrained, num_classes=args.num_classes)

    # Reload a previously trained model
    if args.resume:
        if os.path.isfile(args.resume):
            logging.info(f"=> loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
            # NOTE(review): with strict=True, load_state_dict raises on missing
            # keys, so the missing_keys logged below is always empty.
            acc = model.load_state_dict(checkpoint['state_dict'], strict=True)
            logging.info(f'missing keys of models: {acc.missing_keys}')
            del checkpoint
        else:
            raise FileNotFoundError(f"No checkpoint found at '{args.resume}' to be resumed")

    # Model info
    image_height, image_width = args.image_size
    logging.info(f'Model {args.arch} input size: ({image_height}, {image_width})')
    utils.summary(size=(image_height, image_width), channel=3, model=model)

    # Model conversion: export to torch.jit.script
    if args.jit:
        if not args.resume:
            raise ValueError('Option --resume must specified!')
        applications.convert_to_jit(model, args=args)
        return

    if args.criterion == 'softmax':
        criterion = criterions.HybridCELoss(args=args)  # hybrid-strategy multi-class
    elif args.criterion == 'bce':
        criterion = criterions.HybridBCELoss(args=args)  # hybrid-strategy multi-label binary
    else:
        raise NotImplementedError(f'Not loss function {args.criterion}')

    if args.cuda:
        if args.distributed and args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    if args.knowledge in ('train', 'test', 'val'):
        torch.set_flush_denormal(True)
        distill_loader = dataloader.load(args, name=args.knowledge)
        applications.distill(distill_loader, model, criterion, args, is_confuse_matrix=True)
        return

    if args.make_curriculum in ('train', 'test', 'val'):
        torch.set_flush_denormal(True)
        curriculum_loader = dataloader.load(args, name=args.make_curriculum)
        applications.make_curriculum(curriculum_loader, model, criterion, args, is_confuse_matrix=True)
        return

    if args.visual_data in ('train', 'test', 'val'):
        torch.set_flush_denormal(True)
        test_loader = dataloader.load(args, name=args.visual_data)
        applications.Visualize.visualize(test_loader, model, args)
        return

    # Optimizer
    opt_set = {
        'sgd': partial(torch.optim.SGD, momentum=args.momentum),
        'adam': torch.optim.Adam, 'adamw': AdamW,
        'radam': RAdam, 'ranger': Ranger, 'lookaheadadam': LookaheadAdam,
        'ralamb': Ralamb, 'rangerlars': RangerLars,
        'novograd': Novograd,
    }
    optimizer = opt_set[args.opt](model.parameters(), lr=args.lr)  # weight decay is applied in train instead
    # Stochastic weight averaging (SWA) optimizer
    # from optim.swa import SWA
    # optimizer = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.05)

    # Mixed-precision training
    if args.cuda:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.distributed:
        model = apex.parallel.DistributedDataParallel(model)
    else:
        model = torch.nn.DataParallel(model)

    if args.train:
        train_loader = dataloader.load(args, 'train')
        val_loader = dataloader.load(args, 'val')
        scheduler = LambdaLR(optimizer,
                             lambda epoch: adjust_learning_rate(epoch, args=args))
        applications.train(train_loader, val_loader, model, criterion, optimizer, scheduler, args)
        # Always evaluate on the test split after training.
        args.evaluate = True

    if args.evaluate:
        torch.set_flush_denormal(True)
        test_loader = dataloader.load(args, name='test')
        acc, loss, paths_targets_preds_probs = applications.test(test_loader, model,
                                                                 criterion, args, is_confuse_matrix=True)
        logging.info(f'Evaluation: * Acc@1 {acc:.3f} and loss {loss:.3f}.')
        logging.info('Evaluation Result:\n')
        for path, target, pred, prob in paths_targets_preds_probs:
            logging.info(path + ' ' + str(target) + ' ' + str(pred) + ' ' + ','.join([f'{num:.2f}' for num in prob]))
        logging.info('Evaluation Over~')
コード例 #6
0
ファイル: main.py プロジェクト: IntelAI/models
def main_worker(gpu, ngpus_per_node, args):
    """Build, optionally optimize (IPEX / JIT / int8), and train or evaluate
    an image classifier in one worker process.

    Args:
        gpu: GPU index assigned to this worker, or ``None``.
        ngpus_per_node: number of GPUs per node; used to derive the global
            rank and to split ``batch_size`` / ``workers`` per process.
        args: parsed command-line options; ``gpu``, ``rank``, ``batch_size``,
            ``workers`` and ``start_epoch`` may be updated in place.

    Side effects:
        Updates the module-level ``best_acc1``. In distributed mode it sets
        ``RANK``/``WORLD_SIZE``/``MASTER_ADDR``/``MASTER_PORT`` in the
        environment and initializes ``torch.distributed``. During training it
        calls ``save_checkpoint`` every epoch.
    """
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        # Prefer PMI_RANK / PMI_SIZE from the environment when present
        # (presumably exported by an MPI launcher).
        os.environ['RANK'] = str(os.environ.get('PMI_RANK', args.rank))
        os.environ['WORLD_SIZE'] = str(
            os.environ.get('PMI_SIZE', args.world_size))
        os.environ['MASTER_ADDR'] = args.master_addr
        # os.environ only accepts strings; tolerate an integer port option.
        os.environ['MASTER_PORT'] = str(args.port)
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu

        # Initialize the process group with ccl backend
        if args.dist_backend == 'ccl':
            import torch_ccl  # noqa: F401  (imported for its side effects)
        dist.init_process_group(backend=args.dist_backend)
        #dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
        #                        world_size=args.world_size, rank=args.rank)
    if args.hub:
        torch.set_flush_denormal(True)
        model = torch.hub.load('facebookresearch/WSL-Images', args.arch)
    else:
        # create model
        if args.pretrained:
            print("=> using pre-trained model '{}'".format(args.arch))
            model = models.__dict__[args.arch](pretrained=True)
        else:
            print("=> creating model '{}'".format(args.arch))
            model = models.__dict__[args.arch]()

    if args.ipex:
        import intel_extension_for_pytorch as ipex
    # for ipex path, always convert model to channels_last for bf16, fp32.
    # TODO: int8 path: https://jira.devtools.intel.com/browse/MFDNN-6103
    if args.ipex and not args.int8:
        model = model.to(memory_format=torch.channels_last)

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None and args.cuda:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(
                (args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu])
        else:
            if args.cuda:
                model.cuda()
                print("create DistributedDataParallel in GPU")
                # DistributedDataParallel will divide and allocate batch_size to all
                # available GPUs if device_ids are not set
                model = torch.nn.parallel.DistributedDataParallel(model)
            else:
                # The CPU DDP wrapper is created later, after ipex.optimize.
                print("create DistributedDataParallel in CPU")
    elif args.gpu is not None and args.cuda:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            if args.cuda:
                model.cuda()
        else:
            model = torch.nn.DataParallel(model)
            # `args.cuda` is a flag, not a callable (calling it raised TypeError).
            if args.cuda:
                model.cuda()

    # define loss function (criterion) and optimizer

    criterion = nn.CrossEntropyLoss()
    if args.cuda:
        criterion = criterion.cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if not args.cuda:
                # CPU run: map any CUDA tensors in the checkpoint onto the CPU
                # (previously this mapped to 'cuda:<gpu>' and failed on CPU).
                checkpoint = torch.load(args.resume, map_location='cpu')
            elif args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None and args.cuda:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.cuda:
        cudnn.benchmark = True

    if args.weight_sharing:
        # NOTE(review): the message asks for batch_size 1, but the condition
        # only checks that batch_size is non-zero — confirm the intent.
        assert args.dummy and args.batch_size, \
                "please using dummy data and set batch_size to 1 if you want run weight sharing case for latency case"
    if args.jit and args.int8:
        assert False, "jit path is not available for int8 path using ipex"
    if args.calibration:
        assert args.int8, "please enable int8 path if you want to do int8 calibration path"
    if args.dummy:
        assert args.evaluate, "please using real dataset if you want run training path"
    if not args.ipex:
        # for offical pytorch, int8 and jit path is not enabled.
        assert not args.int8, "int8 path is not enabled for offical pytorch"
        assert not args.jit, "jit path is not enabled for offical pytorch"

    if not args.dummy:
        # Data loading code
        assert args.data is not None, "please set dataset path if you want to using real data"
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        if not args.evaluate:
            traindir = os.path.join(args.data, 'train')
            train_dataset = datasets.ImageFolder(
                traindir,
                transforms.Compose([
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]))

            if args.distributed:
                train_sampler = torch.utils.data.distributed.DistributedSampler(
                    train_dataset)
            else:
                train_sampler = None

            train_loader = torch.utils.data.DataLoader(
                train_dataset,
                batch_size=args.batch_size,
                shuffle=(train_sampler is None),
                num_workers=args.workers,
                pin_memory=True,
                sampler=train_sampler)

        val_loader = torch.utils.data.DataLoader(datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
                                                 batch_size=args.batch_size,
                                                 shuffle=False,
                                                 num_workers=args.workers,
                                                 pin_memory=True)
    else:
        train_loader = None
        val_loader = None

    if args.evaluate:
        if args.ipex:
            print("using ipex model to do inference\n")
        else:
            print("using offical pytorch model to do inference\n")

        if args.ipex:
            model.eval()
            if args.int8:
                if not args.calibration:
                    model = optimization.fuse(model, inplace=True)
                    conf = ipex.quantization.QuantConf(args.configure_dir)
                    x = torch.randn(
                        args.batch_size, 3, 224,
                        224).contiguous(memory_format=torch.channels_last)
                    model = ipex.quantization.convert(model, conf, x)
                    with torch.no_grad():
                        y = model(x)
                        print(model.graph_for(x))
                    print("running int8 evalation step\n")
            else:
                if args.bf16:
                    model = ipex.optimize(model,
                                          dtype=torch.bfloat16,
                                          inplace=True)
                    print("running bfloat16 evalation step\n")
                else:
                    model = ipex.optimize(model,
                                          dtype=torch.float32,
                                          inplace=True)
                    print("running fp32 evalation step\n")
                if args.jit:
                    x = torch.randn(
                        args.batch_size, 3, 224,
                        224).contiguous(memory_format=torch.channels_last)
                    if args.bf16:
                        x = x.to(torch.bfloat16)
                        with torch.cpu.amp.autocast(), torch.no_grad():
                            model = torch.jit.trace(model, x).eval()
                    else:
                        with torch.no_grad():
                            model = torch.jit.trace(model, x).eval()
                    model = torch.jit.freeze(model)
        validate(val_loader, model, criterion, args)
        return

    if args.ipex:
        if args.bf16:
            model, optimizer = ipex.optimize(model,
                                             dtype=torch.bfloat16,
                                             optimizer=optimizer)
        else:
            model, optimizer = ipex.optimize(model,
                                             dtype=torch.float32,
                                             optimizer=optimizer)

    # parallelize
    if args.distributed and not args.cuda and args.gpu is None:
        print("create DistributedDataParallel in CPU")
        device_ids = None
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=device_ids)

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (
                args.multiprocessing_distributed
                and args.rank % ngpus_per_node == 0):
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
コード例 #7
0
def main():
    """Fit and plot density networks on a toy 2-D -> 2-D regression task.

    Trains a single-Gaussian density network and a Gaussian mixture density
    network (MLPs with CDropout layers) on a synthetic dataset built from five
    shifted copies of a 25x25 input grid, then shows 3-D scatter plots of
    samples from each model against the true function ``f2d``.

    Command-line options are described by their ``argparse`` definitions.
    """
    # model parameters
    parser = argparse.ArgumentParser("BNN regression example")
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--net_shape',
                        type=lambda s: [int(d) for d in s.split(',')],
                        default=[200, 200, 200, 200])
    parser.add_argument('--drop_rate', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--n_components', type=int, default=5)
    parser.add_argument('--N_batch', type=int, default=100)
    parser.add_argument('--train_iters', type=int, default=15000)
    parser.add_argument('--noise_level', type=float, default=1e-3)
    parser.add_argument('--resample', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(args.num_threads)
    torch.set_flush_denormal(True)

    if args.use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    def make_grid(limit, steps):
        """Return a (steps * steps, 2) grid of points covering [-limit, limit]^2."""
        axis = torch.linspace(-limit, limit, steps)
        return torch.stack(torch.meshgrid(axis, axis)).T.reshape(-1, 2)

    def set_window_title(fig, title):
        """Set a figure's window title on old and new Matplotlib versions.

        ``FigureCanvasBase.set_window_title`` was removed in Matplotlib 3.6;
        the figure manager provides the same method.
        """
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(title)

    # create toy dataset
    def f2d(X, multimodal=False):
        x, y = X[:, 0:1], X[:, 1:2]
        z = (1 + ((1 - x)**2 + (y - x**2)**2)).log()
        return torch.cat([z.sin(), z.cos()], -1)

    X = make_grid(0.35, 25)
    X = torch.cat([
        X, X + torch.tensor([0.75, 0.75]), X + torch.tensor([0.75, -0.75]),
        X + torch.tensor([-0.75, -0.75]), X + torch.tensor([-0.75, 0.75])
    ], 0)
    Y = f2d(X)
    Y += torch.randn_like(Y) * args.noise_level
    xx = make_grid(1.5, 50)
    yy = f2d(xx)

    fig = plt.figure()
    ax1 = fig.add_subplot(211, projection='3d')
    xx_, yy_ = xx.cpu(), yy.cpu()
    X_, Y_ = X.cpu(), Y.cpu()
    ax1.scatter(xx_[:, 0], xx_[:, 1], yy_[:, 0], s=1, alpha=0.25)
    ax1.scatter(X_[:, 0], X_[:, 1], Y_[:, 0])
    ax2 = fig.add_subplot(212, projection='3d')
    ax2.scatter(xx_[:, 0], xx_[:, 1], yy_[:, 1], s=1, alpha=0.25)
    ax2.scatter(X_[:, 0], X_[:, 1], Y_[:, 1])

    print('Dataset size:', X.shape[0], 'samples')

    # single gaussian model
    input_dims = 2
    output_dims = 2
    hids = args.net_shape

    model = models.density_network_mlp(
        input_dims, output_dims, models.GaussianDN, hids,
        [models.CDropout(args.drop_rate * torch.ones(hid))
         for hid in hids], models.activations.hhSinLU)
    model.set_scaling(X, Y)
    print(model)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    utils.train_model(model,
                      X,
                      Y,
                      n_iters=args.train_iters,
                      opt=opt,
                      resample=args.resample,
                      batch_size=args.N_batch)
    print(model)

    # mixture of gaussians model
    nc = args.n_components
    mmodel = models.mixture_density_network_mlp(
        input_dims, output_dims, nc, models.GaussianMDN, hids,
        [models.CDropout(args.drop_rate * torch.ones(hid))
         for hid in hids], models.activations.hhSinLU)
    mmodel.set_scaling(X, Y)
    print(mmodel)

    opt = torch.optim.Adam(mmodel.parameters(), lr=args.lr)
    utils.train_model(mmodel,
                      X,
                      Y,
                      n_iters=args.train_iters,
                      opt=opt,
                      resample=args.resample,
                      batch_size=args.N_batch)
    print(mmodel)

    # Evaluation grid shared by both result plots; each input point is
    # repeated 10 times along dim 1 so the models return 10 outputs per point.
    xx = make_grid(2.75, 50)
    yy = f2d(xx)
    xq = xx[:, None].repeat(1, 10, 1)
    xx_, yy_ = xx.cpu(), yy.cpu()

    # plot results for single gaussian model
    with torch.no_grad():
        model.resample()
        py, _ = model(xq, temperature=1.0, resample=False)
        noiseless_py, _ = model(xq, temperature=1.0e-9, resample=False)

    colors = np.array(list(plt.cm.rainbow_r(np.linspace(0, 1, 255))))
    fig = plt.figure(figsize=(16, 9))
    set_window_title(fig, 'Single Gaussian output density')
    # Move samples to the CPU: matplotlib cannot plot CUDA tensors (--use_cuda).
    samples = py.sample().cpu()
    noiseless_samples = noiseless_py.sample().cpu()
    n_rep = samples.shape[1]
    grid_x0 = xx_[:, 0:1].repeat(1, n_rep).view(-1)
    grid_x1 = xx_[:, 1:2].repeat(1, n_rep).view(-1)

    for i in range(yy.shape[-1]):
        ax1 = fig.add_subplot(int(f'{yy.shape[-1]}1{i+1}'), projection='3d')
        ax1.scatter(xx_[:, 0], xx_[:, 1], yy_[:, i], s=2, alpha=0.25)
        ax1.scatter(X_[:, 0], X_[:, 1], Y_[:, i])
        ax1.scatter(grid_x0,
                    grid_x1,
                    noiseless_samples[..., i].view(-1),
                    s=2,
                    c=colors[0:1],
                    alpha=0.05)
        ax1.scatter(grid_x0,
                    grid_x1,
                    samples[..., i].view(-1),
                    s=2,
                    c=colors[0:1],
                    alpha=0.05)
        ax1.set_zlim3d(yy_[:, i].min().item(), yy_[:, i].max().item())

    # plot results for gaussian mixture model
    with torch.no_grad():
        mmodel.resample()
        py, py_params = mmodel(xq, temperature=1.0, resample=False)
        noiseless_py, noiseless_py_params = mmodel(xq,
                                                   temperature=1.0e-9,
                                                   resample=False)

    fig = plt.figure(figsize=(16, 9))
    set_window_title(fig, 'Mixture of Gaussians output density')
    samples = py.sample().cpu()
    noiseless_samples = noiseless_py.sample().cpu()
    # Samples are coloured by the argmax over the mixture logits.
    logit_pi = py_params['logit_pi'].squeeze(-1).cpu()
    noiseless_logit_pi = noiseless_py_params['logit_pi'].squeeze(-1).cpu()
    n_rep = samples.shape[1]
    grid_x0 = xx_[:, 0:1].repeat(1, n_rep).view(-1)
    grid_x1 = xx_[:, 1:2].repeat(1, n_rep).view(-1)

    for i in range(yy.shape[-1]):
        ax1 = fig.add_subplot(int(f'{yy.shape[-1]}1{i+1}'), projection='3d')
        ax1.scatter(xx_[:, 0], xx_[:, 1], yy_[:, i], s=2, alpha=0.25)
        ax1.scatter(X_[:, 0], X_[:, 1], Y_[:, i])
        ax1.scatter(grid_x0,
                    grid_x1,
                    noiseless_samples[..., i].view(-1),
                    s=20,
                    c=noiseless_logit_pi.argmax(-1).view(-1),
                    alpha=0.01)
        ax1.scatter(grid_x0,
                    grid_x1,
                    samples[..., i].view(-1),
                    s=20,
                    c=logit_pi.argmax(-1).view(-1),
                    alpha=0.1)
        ax1.set_zlim3d(yy_[:, i].min().item(), yy_[:, i].max().item())

    plt.show()
コード例 #8
0
ファイル: bnn_regression.py プロジェクト: mcgillmrl/prob_mbrl
def main():
    """Fit and plot density networks on a toy 1-D regression task.

    Trains a single-Gaussian density network and a Gaussian mixture density
    network (MLPs with CDropout layers) on samples of a two-term sine series,
    then plots predicted means/stds and mixture samples against the truth.

    Command-line options are described by their ``argparse`` definitions.
    """
    # model parameters
    parser = argparse.ArgumentParser("BNN regression example")
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--net_shape',
                        type=lambda s: [int(d) for d in s.split(',')],
                        default=[200, 200, 200, 200])
    parser.add_argument('--drop_rate', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--n_components', type=int, default=5)
    parser.add_argument('--N_batch', type=int, default=100)
    parser.add_argument('--train_iters', type=int, default=15000)
    parser.add_argument('--noise_level', type=float, default=0)
    parser.add_argument('--resample', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(args.num_threads)
    torch.set_flush_denormal(True)

    if args.use_cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    def set_window_title(fig, title):
        """Set a figure's window title on old and new Matplotlib versions.

        ``FigureCanvasBase.set_window_title`` was removed in Matplotlib 3.6;
        the figure manager provides the same method.
        """
        manager = getattr(fig.canvas, 'manager', None)
        if manager is not None:
            manager.set_window_title(title)

    # create toy dataset
    def f(x, multimodal=False):
        c = 100
        if multimodal:
            # Randomly flip the sign of each element (+1 or -1).
            c *= 2 * torch.randint_like(x, 2) - 1
        return c * sum([
            torch.sin(-2 * np.pi * (2 * k - 1) * x) / (2 * k - 1)
            for k in range(1, 3)
        ])

    # create training dataset
    train_x = torch.cat([
        torch.arange(-0.6, -0.25, 0.01),
        torch.arange(0.1, 0.45, 0.005),
        torch.arange(0.7, 1.25, 0.01)
    ])
    train_y = f(train_x, False)
    train_y += 0.01 * torch.randn(*train_y.shape)
    X = train_x[:, None]
    Y = train_y[:, None]
    Y = Y + torch.randn_like(Y) * args.noise_level
    plt.scatter(X.cpu(), Y.cpu())
    xx = torch.linspace(-.1 + X.min(), .1 + X.max())
    yy = f(xx)
    plt.plot(xx.cpu(), yy.cpu(), linestyle='--')
    print('Dataset size:', train_x.shape[0], 'samples')

    # single gaussian model
    input_dims = 1
    output_dims = 1
    hids = args.net_shape

    model = models.density_network_mlp(
        input_dims, output_dims, models.GaussianDN, hids,
        [models.CDropout(args.drop_rate * torch.ones(hid))
         for hid in hids], models.activations.hhSinLU)
    model.set_scaling(X, Y)
    print(model)

    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    utils.train_model(model,
                      X,
                      Y,
                      n_iters=args.train_iters,
                      opt=opt,
                      resample=args.resample,
                      batch_size=args.N_batch)
    print(model)

    # mixture of gaussians model
    nc = args.n_components
    mmodel = models.mixture_density_network_mlp(
        input_dims, output_dims, nc, models.GaussianMDN, hids,
        [models.CDropout(args.drop_rate * torch.ones(hid))
         for hid in hids], models.activations.hhSinLU)
    mmodel.set_scaling(X, Y)
    print(mmodel)

    opt = torch.optim.Adam(mmodel.parameters(), lr=args.lr)
    utils.train_model(mmodel,
                      X,
                      Y,
                      n_iters=args.train_iters,
                      opt=opt,
                      resample=args.resample,
                      batch_size=args.N_batch)
    print(mmodel)

    # plot results for single gaussian model
    # 500 inputs, each repeated 100 times along dim 1.
    xx = torch.linspace(-2.5 + X.min(), 2.5 + X.max(), 500)
    xx = xx[:, None, None].repeat(1, 100, 1)
    with torch.no_grad():
        model.resample()
        py, py_params = model(xx, temperature=1.0, resample=False)

    xx = xx.cpu()
    colors = np.array(list(plt.cm.rainbow_r(np.linspace(0, 1, 255))))
    fig = plt.figure(figsize=(16, 9))
    set_window_title(fig, 'Single Gaussian output density')
    means = py_params['mu'].squeeze(-1).cpu()
    stds = py_params['sqrtSigma'].diagonal(0, -1, -2).squeeze(-1).cpu()
    plt.plot(xx.squeeze(-1), means, c=colors[0], alpha=0.1)
    for i in range(means.shape[1]):
        # Shade a +/- one standard deviation band around each predicted mean.
        plt.fill_between(xx[:, i].squeeze(-1),
                         means[:, i] - stds[:, i],
                         means[:, i] + stds[:, i],
                         color=0.5 * colors[0],
                         alpha=0.1)
    plt.scatter(X.cpu(), Y.cpu())
    yy = f(xx[:, 0]).cpu()
    plt.plot(xx[:, 0], yy, linestyle='--')
    plt.ylim(2.5 * yy.min(), 1.5 * yy.max())

    # plot results for gaussian mixture model
    xx = torch.linspace(-2.5 + X.min(), 2.5 + X.max(), 500)
    xx = xx[:, None, None].repeat(1, 100, 1)
    with torch.no_grad():
        mmodel.resample()
        py, py_params = mmodel(xx, temperature=1.0, resample=False)
        noiseless_py, noiseless_py_params = mmodel(xx,
                                                   temperature=1.0e-9,
                                                   resample=False)
    xx = xx.cpu()
    fig = plt.figure(figsize=(16, 9))
    set_window_title(fig, 'Mixture of Gaussians output density')
    logit_pi = py_params['logit_pi'].squeeze(-1).cpu()
    noiseless_logit_pi = noiseless_py_params['logit_pi'].squeeze(-1).cpu()

    samples = py.sample([10])
    noiseless_samples = noiseless_py.sample([1])

    # Samples are coloured by the argmax over the mixture logits.
    plt.scatter(xx.repeat(samples.shape[0], 1, 1).squeeze(-1),
                samples.view(-1, samples.shape[-2]).cpu(),
                c=logit_pi.argmax(-1).repeat(samples.shape[0], 1),
                alpha=0.1,
                s=1,
                cmap='copper')
    plt.scatter(xx.repeat(noiseless_samples.shape[0], 1, 1).squeeze(-1),
                noiseless_samples.view(
                    -1, noiseless_samples.shape[-2]).cpu(),
                c=noiseless_logit_pi.argmax(-1),
                alpha=0.5,
                s=1,
                cmap='copper')

    plt.scatter(X.cpu(), Y.cpu())
    yy = f(xx[:, 0]).cpu()
    plt.plot(xx[:, 0], yy, linestyle='--')
    plt.ylim(2.75 * yy.min(), 2.75 * yy.max())
    plt.show()
コード例 #9
0
# Basic tensor predicates on a random float tensor.
x = torch.rand(5, 3)
print(x)
for predicate in (torch.is_tensor, torch.is_storage, torch.is_floating_point):
    print(predicate(x))

# Default dtype: float32 out of the box, switchable to float64.
print()
print(torch.get_default_dtype())  # torch.float32
print(torch.tensor([1.2, 3]).dtype)  # default is torch.float32
torch.set_default_dtype(torch.float64)
print(torch.tensor([1.2, 3]).dtype)

# Setting the default tensor type also changes the default dtype.
print()
torch.set_default_dtype(torch.float64)
print(torch.get_default_dtype())
torch.set_default_tensor_type(torch.FloatTensor)
print(torch.get_default_dtype())

# numel counts every element, whatever the shape.
print()
for x in (torch.tensor([5, 3]), torch.rand(5, 3)):
    print(x)
    print(torch.numel(x))

# set_flush_denormal reports whether the mode could be configured.
print()
for enabled in (True, False):
    print(torch.set_flush_denormal(enabled))
print(torch.set_flush_denormal(False))
コード例 #10
0
ファイル: train.py プロジェクト: uestc-hjw/MichiGAN
"""
Copyright (C) University of Science and Technology of China.
Licensed under the MIT License.
"""
import torch

# Flush denormal floats to zero for faster CPU arithmetic. The call returns
# False when the build/CPU cannot do it (the PyTorch docs limit support to
# x86 CPUs with SSE3); in that case explain why and keep going.
if not torch.set_flush_denormal(True):
    print("Unable to set flush denormal",
          "Pytorch compiled without advanced CPU",
          "at: https://github.com/pytorch/pytorch/blob/84b275b70f73d5fd311f62614bccc405f3d5bfa3/aten/src/ATen/cpu/FlushDenormal.cpp#L13",
          sep="\n")
import sys
from collections import OrderedDict
from options.train_options import TrainOptions
import data
from util.iter_counter import IterationCounter
from util.visualizer import Visualizer
from trainers.pix2pix_trainer import Pix2PixTrainer


# Parse the training options, then echo the exact command line so a run can
# be reproduced while debugging.
opt = TrainOptions().parse()
print(*sys.argv)

# Build the main data loader. For unpaired training a second loader is built
# as well (the `2` presumably selects the second dataset -- confirm in
# `data.create_dataloader`).
dataloader = data.create_dataloader(opt)
if opt.unpairTrain:
    dataloader2 = data.create_dataloader(opt, 2)

# create trainer for our model
コード例 #11
0
def main():
    """Generate `num_imgs` images from random noise with a GAN generator.

    Settings are read from the command line (see the arguments below) and may
    be partly overridden by a saved parameter file. The generator is built
    from those settings and optionally restored from a saved checkpoint. The
    images it produces are written to `save_dir` as zero-padded, numbered PNGs.

    Every argument except `--num_imgs` accepts one or more values. A single
    value is unwrapped and several values become a tuple. Booleans and `None`
    are given as the strings `'True'`, `'False'` and `'None'`.
    """
    # `__builtins__` is the builtins module only in `__main__`; in an imported
    # module it is a plain dict, so builtin names are resolved via this module.
    import builtins

    # Parameters

    # Paths are written in UNIX-like notation!
    # So write `C:\Users\user\GANerator` as `C:/Users/user/GANerator` or `~/GANerator`.

    # All parameters that take classes also accept strings of the class.

    # Only parameters in the 'Data and Models' section will be saved and loaded!

    parser = argparse.ArgumentParser()
    # Experiment specific
    # ===================
    parser.add_argument('--num_imgs',
                        help='How many images to generate.',
                        default=30000)
    parser.add_argument(
        '--exp_name',
        nargs='+',
        help=
        "File names for this experiment. If `None` or `''`, `append_time` is always `True`.",
        default=None)
    parser.add_argument(
        '--append_time',
        nargs='+',
        help=
        "Append the current time to the file names (to prevent overwriting).",
        default=True)
    parser.add_argument(
        '--load_dir',
        nargs='+',
        help=
        "Directory to load saved files from. If `save_dir` is `None`, this also acts as `save_dir`.",
        default='.')
    parser.add_argument(
        '--save_dir',
        nargs='+',
        help="Directory to save to. If `None`, use the value of `load_dir`.",
        default='.')

    parser.add_argument(
        '--load_exp',
        nargs='+',
        help=
        "Load the models and parameters from this experiment (previous `exp_name`). Also insert the optionally appended time (WIP: if this value is otherwise ambiguous). Set the parameters `models_file` or `params_file` below to use file names. If set to `True`, use `exp_name`. If `False` or `None`, do not load.",
        default=False)
    parser.add_argument(
        '--params_file',
        nargs='+',
        help=
        "Load parameters from this path. Set to `False` to not load. Priority over `load_exp`. Set to `True` to ignore this so it does not override `load_exp`.",
        default=True)
    parser.add_argument(
        '--models_file',
        nargs='+',
        help=
        "Load models from this path. Set to `False` to not load. Priority over `load_exp`. Set to `True` to ignore this so it does not override `load_exp`.",
        default=True)
    parser.add_argument(
        '--load_weights_only',
        nargs='+',
        help=
        "Load only the models' weights. To continue training, set this to `False`.",
        default=True)

    parser.add_argument(
        '--save_params',
        nargs='+',
        help="Save the parameters in the 'Data and Models' section to a file.",
        default=False)
    parser.add_argument(
        '--save_weights_only',
        nargs='+',
        help=
        "Save only the models' weights. To continue training later, set this to `False`.",
        default=False)
    parser.add_argument(
        '--checkpoint_period',
        nargs='+',
        help=
        "After how many steps to save a model checkpoint. Set to `0` to only save when finished.",
        default=100)

    parser.add_argument(
        '--num_eval_imgs',
        nargs='+',
        help="How many images to generate for (temporal) evaluation.",
        default=64)

    # Hardware and Multiprocessing
    # ============================
    parser.add_argument(
        '--num_workers',
        nargs='+',
        help=
        "Amount of worker threads to create on the CPU. Set to `0` to use CPU count.",
        default=0)
    parser.add_argument(
        '--num_gpus',
        nargs='+',
        help=
        "Amount of GPUs to use. `None` to use all available ones. Set to `0` to run on CPU only.",
        default=None)
    parser.add_argument(
        '--cuda_device_id',
        nargs='+',
        help="ID of CUDA device. In most cases, this should be left at `0`.",
        default=0)

    # Reproducibility
    # ===============
    parser.add_argument(
        '--seed',
        nargs='+',
        help=
        "Random seed if `None`. The used seed will always be saved in `saved_seed`.",
        default=0)
    parser.add_argument(
        '--ensure_reproducibility',
        nargs='+',
        help=
        "If using cuDNN: Set to `True` to ensure reproducibility in favor of performance.",
        default=False)
    parser.add_argument(
        '--flush_denormals',
        nargs='+',
        help=
        "Whether to set denormals to zero. Some architectures do not support this.",
        default=True)

    # Data and Models
    # ===============
    # Only parameters in this section will be saved and updated when loading.

    parser.add_argument(
        '--dataset_root',
        nargs='+',
        help=
        "Path to the root folder of the data set. This value is only loaded if set to `None`!",
        default='~/datasets/ffhq')
    parser.add_argument(
        '--dataset_class',
        nargs='+',
        help=
        "Set this to the torchvision.datasets class (module `dsets`). This value is only loaded if set to `None`!",
        default=dsets.ImageFolder)
    parser.add_argument('--epochs',
                        nargs='+',
                        help="Number of training epochs.",
                        default=5)
    parser.add_argument(
        '--batch_size',
        nargs='+',
        help=
        "Size of each training batch. Strongly depends on other parameters.",
        default=512)
    parser.add_argument(
        '--img_channels',
        nargs='+',
        help=
        "Number of channels in the input images. Normally 3 for RGB and 1 for grayscale.",
        default=3)
    parser.add_argument(
        '--img_shape',
        nargs='+',
        help=
        "Shape of the output images (excluding channel dimension). Can be an integer to get squares. At the moment, an image can only be square sized and a power of two.",
        default=64)
    parser.add_argument(
        '--resize',
        nargs='+',
        help="If `True`, resize images; if `False`, crop (to the center).",
        default=True)

    parser.add_argument('--data_mean',
                        nargs='+',
                        help="Data is normalized to this mean (per channel).",
                        default=0.0)
    parser.add_argument(
        '--data_std',
        nargs='+',
        help="Data is normalized to this standard deviation (per channel).",
        default=1.0)
    parser.add_argument('--float_dtype',
                        nargs='+',
                        help="Float precision as `torch.dtype`.",
                        default=torch.float32)
    parser.add_argument(
        '--g_input',
        nargs='+',
        help="Size of the generator's random input vectors (`z` vector).",
        default=128)

    # GAN hacks
    parser.add_argument(
        '--g_flip_labels',
        nargs='+',
        help="Switch labels for the generator's training step.",
        default=False)
    parser.add_argument(
        '--d_noisy_labels_prob',
        nargs='+',
        help="Probability to switch labels when training the discriminator.",
        default=0.0)
    parser.add_argument(
        '--smooth_labels',
        nargs='+',
        help="Replace discrete labels with slightly different continuous ones.",
        default=False)

    # Values in this paragraph can be either a single value (e.g. an `int`) or a 2-`tuple` of the same type.
    # If a single value, that value will be applied to both the discriminator and generator network.
    # If a 2-`tuple`, the first value will be applied to the discriminator, the second to the generator.
    parser.add_argument(
        '--features',
        nargs='+',
        help="Relative size of the network's internal features.",
        default=64)
    parser.add_argument(
        '--optimizer',
        nargs='+',
        help="Optimizer class. GAN hacks recommends `(optim.SGD, optim.Adam)`.",
        default=optim.Adam)
    parser.add_argument(
        '--lr',
        nargs='+',
        help=
        "Optimizer learning rate. (Second optimizer argument, so not necessarily learning rate.)",
        default=0.0002)
    parser.add_argument(
        '--optim_param',
        nargs='+',
        help=
        "Third optimizer argument. (For example, `betas` for `Adam` or `momentum` for `SGD`.)",
        default=((0.5, 0.999), ))
    parser.add_argument(
        '--optim_kwargs',
        nargs='+',
        help="Any further optimizer keyword arguments as a dictionary.",
        default={})
    parser.add_argument(
        '--normalization',
        nargs='+',
        help=
        "Kind of normalization. Must be a `Norm` or in `('b', 'v', 's', 'i', 'a', 'n')`. Usually, spectral normalization is used in the discriminator while virtual batch normalization is used in the generator.",
        default=Norm.BATCH)
    parser.add_argument(
        '--activation',
        nargs='+',
        help=
        "Activation between hidden layers. GAN hacks recommends `nn.LeakyReLU`.",
        default=(nn.LeakyReLU, nn.ReLU))
    parser.add_argument('--activation_kwargs',
                        nargs='+',
                        help="Activation keyword arguments.",
                        default=({
                            'negative_slope': 0.2,
                            'inplace': True
                        }, {
                            'inplace': True
                        }))
    params = vars(parser.parse_args())
    # Unwrap the `nargs='+'` lists: one value -> that value, several -> tuple.
    for key, val in params.items():
        if type(val) is list:
            params[key] = val[0] if len(val) == 1 else tuple(val)

    # Process parameters

    num_imgs = int(params['num_imgs'])

    # Model parameters as tuples. If it is a tuple, give the class to return as well.
    # If the class is given as `'eval'`, the parameter is literally evaluated if either
    # the tuple or its content begins with a symbol in '({['.
    tuple_params = (
        ('features', int),
        ('optimizer', 'eval'),
        ('lr', float),
        ('optim_param', 'eval'),
        ('optim_kwargs', 'eval'),
        ('normalization', 'eval'),
        ('activation', 'eval'),
        ('activation_kwargs', 'eval'),
    )

    # Parameters that we do *not* want to save (or load).
    # We list these instead of the model parameters as those should be easier to extend.
    static_params = [
        'exp_name',
        'append_time',
        'load_dir',
        'save_dir',
        'load_exp',
        'params_file',
        'models_file',
        'load_weights_only',
        'save_params',
        'save_weights_only',
        'checkpoint_period',
        'num_workers',
        'num_gpus',
        'cuda_device_id',
        'seed',
        'ensure_reproducibility',
        'flush_denormals',
    ]

    def string_to_class(string):
        """Resolve a (dotted) name such as `'optim.Adam'` to the object it names.

        A name without a dot is looked up among the builtins. Otherwise the
        first component must be a global of this module and the remaining
        components are attribute lookups. Non-strings are returned unchanged.
        """
        if type(string) is not str:
            return string
        parts = string.split('.')
        if len(parts) == 1:
            m = builtins
        else:
            m = globals()[parts[0]]
            for part in parts[1:-1]:
                m = getattr(m, part)
        return getattr(m, parts[-1])

    def argstring(string):
        """
        Return a string converted to its value as if evaled or itself.

        `string` is converted to:
        - the corresponding boolean if it is `'True'` or `'False'`
        - None if `'None'`
        - nothing and returned as it is otherwise.
        """
        return {'True': True, 'False': False, 'None': None}.get(string, string)

    # Experiment name

    append_time = argstring(params['append_time'])
    exp_name = argstring(params['exp_name'])
    if not exp_name or append_time:
        # Keep a user-given string name; anything else (`None`, `''`, several
        # names) is dropped so that only the timestamp remains.
        if not isinstance(exp_name, str):
            exp_name = ''
        exp_name = ''.join(
            (exp_name, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')))

    # Load parameters

    load_dir = argstring(params['load_dir'])
    save_dir = argstring(params['save_dir'])
    if save_dir is None:
        save_dir = load_dir

    load_exp = argstring(params['load_exp'])

    params_file = argstring(params['params_file'])
    load_params = params_file and (load_exp or type(params_file) is str)

    dataset_root = argstring(params['dataset_root'])
    dataset_class = string_to_class(params['dataset_class'])

    # Check whether these parameters are `None`.
    # If yes, check that parameters loading is enabled. Otherwise do not update them.
    if dataset_root is None:
        assert load_params, '`dataset_root` cannot be `None` if not loading parameters.'
    else:
        static_params.append('dataset_root')
    if dataset_class is None:
        assert load_params, '`dataset_class` cannot be `None` if not loading parameters.'
    else:
        static_params.append('dataset_class')

    if load_params:
        if type(params_file) is str:
            params_path = Path(params_file)
        elif type(load_exp) is bool:
            # `load_exp` is `True`: use this run's experiment name.
            params_path = Path('{}/params_{}.pt'.format(load_dir, exp_name))
        else:
            params_path = Path('{}/params_{}.pt'.format(load_dir, load_exp))

        params_path = params_path.expanduser()
        # NOTE: `torch.load` unpickles the file; only load files you trust.
        upd_params = torch.load(params_path)
        # Static parameters are never taken from the file (see above).
        params.update({
            key: val
            for key, val in upd_params.items() if key not in static_params
        })
        del upd_params
        # `dataset_root`/`dataset_class` may only now have received a value.
        dataset_root = argstring(params['dataset_root'])
        dataset_class = string_to_class(params['dataset_class'])
    elif params_file == '':
        print(
            "`params_file` is an empty string (`''`). Parameters were not loaded. "
            'Set to `False` to suppress this warning or to `True` to let `load_exp` handle loading.'
        )

    # Hardware and multiprocessing

    num_gpus = argstring(params['num_gpus'])
    cuda_device_id = int(params['cuda_device_id'])
    if num_gpus is None:
        num_gpus = torch.cuda.device_count()
        print('Using {} GPUs.'.format(num_gpus))
    else:
        num_gpus = int(num_gpus)
    use_gpus = num_gpus > 0
    multiple_gpus = num_gpus > 1
    if use_gpus:
        assert torch.cuda.is_available(), 'CUDA is not available. ' \
                'Check what is wrong or set `num_gpus` to `0` to run on CPU.'  # Never check for this again
        device = torch.device('cuda:' + str(cuda_device_id))
    else:
        device = torch.device('cpu')

    num_workers = int(params['num_workers'])
    if not num_workers:
        num_workers = mp.cpu_count()
        print('Using {} worker threads.'.format(num_workers))

    # Load model

    models_file = argstring(params['models_file'])
    models_cp = None
    if models_file and (load_exp or type(models_file) is str):
        if type(models_file) is str:
            models_path = Path(models_file)
        elif type(load_exp) is bool:
            models_path = Path('{}/models_{}.tar'.format(load_dir, exp_name))
        else:
            models_path = Path('{}/models_{}.tar'.format(load_dir, load_exp))
        models_path = models_path.expanduser()
        # NOTE: `torch.load` unpickles the file; only load files you trust.
        models_cp = torch.load(models_path, map_location=device)
    elif models_file == '':
        print(
            "`models_file` is an empty string (`''`). Models were not loaded. "
            'Set to `False` to suppress this warning or to `True` to let `load_exp` handle loading.'
        )

    # Reproducibility

    seed = argstring(params['seed'])
    if seed is None:
        seed = np.random.randint(10000)
    else:
        seed = int(seed)
    print('Seed: {}.'.format(seed))
    params['saved_seed'] = seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    ensure_reproducibility = argstring(params['ensure_reproducibility'])
    torch.backends.cudnn.deterministic = ensure_reproducibility
    if ensure_reproducibility:
        torch.backends.cudnn.benchmark = False  # This is the default but do it anyway

    flush_denormals = argstring(params['flush_denormals'])
    set_flush_success = torch.set_flush_denormal(flush_denormals)
    if flush_denormals and not set_flush_success:
        print('Not able to flush denormals. `flush_denormals` set to `False`.')
        flush_denormals = False

    # Dataset root

    assert dataset_root is not None, \
            '`dataset_root` is `None` and was not found in the loaded parameters.'
    dataset_root = Path(dataset_root).expanduser()

    # Floating point precision

    float_dtype = string_to_class(params['float_dtype'])
    if float_dtype is torch.float16:
        print(
            'PyTorch does not support half precision well yet. Be careful and assume errors.'
        )
    torch.set_default_dtype(float_dtype)

    # Parameters we do not need to process
    # (converted anyway so that malformed values fail early)

    load_weights_only = argstring(params['load_weights_only'])
    save_weights_only = argstring(params['save_weights_only'])
    checkpoint_period = int(params['checkpoint_period'])
    num_eval_imgs = int(params['num_eval_imgs'])

    epochs = int(params['epochs'])
    batch_size = int(params['batch_size'])
    img_channels = int(params['img_channels'])
    resize = argstring(params['resize'])

    data_mean = float(params['data_mean'])
    data_std = float(params['data_std'])
    g_input = int(params['g_input'])

    g_flip_labels = argstring(params['g_flip_labels'])
    d_noisy_labels_prob = float(params['d_noisy_labels_prob'])
    smooth_labels = argstring(params['smooth_labels'])

    assert 0.0 <= d_noisy_labels_prob <= 1.0, \
            'Invalid probability for `d_noisy_labels`. Must be between 0 and 1 inclusively.'

    # Single or tuple parameters

    def param_as_ntuple(key, n=2, return_type=None):
        """Return `params[key]` as an `n`-tuple with each entry converted.

        A single value is repeated `n` times. A shorter tuple/list gets its
        last entry repeated. `return_type` is a callable applied to every
        entry, `None` for no conversion, or `'eval'` to `literal_eval` strings
        (the whole value and each entry) that start with one of `'({['`.
        """
        if return_type == 'eval':

            def return_func(x):
                if type(x) is str and x[0] in '({[':
                    return literal_eval(str(x))
                return x
        elif return_type is None:

            def return_func(x):
                return x
        else:
            return_func = return_type

        val = params[key]
        if return_type == 'eval' and type(val) is str and val[0] in '({[':
            val = literal_eval(val)

        if type(val) in (tuple, list):
            assert 0 < len(
                val
            ) <= n, 'Tuples should have length {} (`{}` is `{}`).'.format(
                n, key, val)
            if len(val) < n:
                if len(val) > 1:
                    print('`{}` is `{}`. Length is less than {}; '.format(
                        key, val, n) +
                          'last entry has been repeated to fit length.')
                val = tuple(val) + (val[-1], ) * (n - len(val))
            return tuple(map(return_func, val))
        return (return_func(val), ) * n

    def ispow2(x):
        """Return whether `x` is an integral power of two."""
        log2 = np.log2(x)
        return log2 == int(log2)

    img_shape = param_as_ntuple('img_shape', return_type=int)
    assert img_shape[0] == img_shape[
        1], '`img_shape` must be square (same width and height).'
    assert ispow2(img_shape[0]), '`img_shape` must be a power of two (2^n).'

    # Split every tuple parameter into its discriminator and generator value.
    d_params = {}
    g_params = {}
    for key in tuple_params:
        if type(key) is tuple:
            key, ret_type = key
            d_params[key], g_params[key] = param_as_ntuple(
                key, return_type=ret_type)
        else:
            d_params[key], g_params[key] = param_as_ntuple(key)

    # Normalization and class parameters

    for p in d_params, g_params:
        normalization = p['normalization']
        if isinstance(normalization,
                      str) and normalization.lower() in ('b', 'v', 's', 'i',
                                                         'a', 'n'):
            # Look up the lower-cased key so that e.g. `'B'` works as well.
            normalization = {
                'b': Norm.BATCH,
                'v': Norm.VIRTUAL_BATCH,
                's': Norm.SPECTRAL,
                'i': Norm.INSTANCE,
                'a': Norm.AFFINE_INSTANCE,
                'n': Norm.NONE
            }[normalization.lower()]
        if not isinstance(normalization, Norm):
            try:
                normalization = Norm(normalization)
            except ValueError:
                normalization = string_to_class(normalization)
            finally:
                assert isinstance(normalization, Norm), \
                        "Unknown normalization. Must be a `Norm` or in `('b', 'v', 's', 'i', 'a', 'n')`."
        p['normalization'] = normalization

        p['optimizer'] = string_to_class(p['optimizer'])
        p['activation'] = string_to_class(p['activation'])

    # Generate example batch (also serves as the VBN reference batch)

    example_noise = torch.randn(batch_size, g_input, 1, 1, device=device)

    # Model helper methods

    @weak_module
    class VirtualBatchNorm2d(nn.Module):
        """Virtual batch normalization for 4D inputs.

        Statistics mix the current batch with a fixed reference batch, which
        is normalized with its own statistics alongside.
        """

        def __init__(self, num_features, eps=1e-5, affine=True):
            super().__init__()
            self.num_features = num_features
            self.eps = eps
            self.affine = affine
            if self.affine:
                self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
                self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            else:
                self.register_parameter('weight', None)
                self.register_parameter('bias', None)
            self.reset_parameters(True)

        def reset_parameters(self, all=False):
            if self.affine:
                nn.init.uniform_(self.weight)
                nn.init.zeros_(self.bias)
            if all:
                # Mixing coefficients are computed lazily on the first call.
                self.in_coef = None
                self.ref_coef = None

        @weak_script_method
        def forward(self, input, ref_batch):
            """Return `(normalized input, normalized reference batch)`."""
            self._check_input_dim(input)
            if self.in_coef is None:
                self._check_input_dim(ref_batch)
                self.in_coef = 1 / (len(ref_batch) + 1)
                self.ref_coef = 1 - self.in_coef

            mean, std, ref_mean, ref_std = self.calculate_statistics(
                input, ref_batch)
            return self.normalize(input, mean,
                                  std), self.normalize(ref_batch, ref_mean,
                                                       ref_std)

        @weak_script_method
        def calculate_statistics(self, input, ref_batch):
            in_mean, in_sqmean = self.calculate_means(input)
            ref_mean, ref_sqmean = self.calculate_means(ref_batch)

            mean = self.in_coef * in_mean + self.ref_coef * ref_mean
            sqmean = self.in_coef * in_sqmean + self.ref_coef * ref_sqmean

            std = torch.sqrt(sqmean - mean**2 + self.eps)
            ref_std = torch.sqrt(ref_sqmean - ref_mean**2 + self.eps)
            return mean, std, ref_mean, ref_std

        # TODO could be @staticmethod, but check @weak_script_method first
        @weak_script_method
        def calculate_means(self, batch):
            mean = torch.mean(batch, 0, keepdim=True)
            sqmean = torch.mean(batch**2, 0, keepdim=True)
            return mean, sqmean

        @weak_script_method
        def normalize(self, batch, mean, std):
            normalized = (batch - mean) / std
            # Without `affine`, `weight` and `bias` are `None`.
            if self.affine:
                normalized = normalized * self.weight + self.bias
            return normalized

        @weak_script_method
        def _check_input_dim(self, input):
            if input.dim() != 4:
                raise ValueError('expected 4D input (got {}D input)'.format(
                    input.dim()))

    def layer_with_norm(layer, norm, features):
        """Return `layer` together with the normalization selected by `norm`."""
        if norm is Norm.BATCH:
            return (layer, nn.BatchNorm2d(features))
        elif norm is Norm.VIRTUAL_BATCH:
            return (layer, VirtualBatchNorm2d(features))
        elif norm is Norm.SPECTRAL:
            return (nn.utils.spectral_norm(layer), )
        elif norm is Norm.INSTANCE:
            return (layer, nn.InstanceNorm2d(features))
        elif norm is Norm.AFFINE_INSTANCE:
            return (layer, nn.InstanceNorm2d(features, affine=True))
        elif norm is Norm.NONE:
            return (layer, )
        else:
            raise ValueError("Unknown normalization `'{}'`".format(norm))

    # Define and initialize generator

    # Generator

    class Generator(nn.Module):
        def __init__(self,
                     normalization,
                     activation,
                     activation_kwargs,
                     img_channels,
                     img_shape,
                     features,
                     g_input,
                     reference_batch=None):
            super().__init__()
            self.layers = self.build_layers(normalization, activation,
                                            activation_kwargs, img_channels,
                                            img_shape, features, g_input)
            if normalization is not Norm.VIRTUAL_BATCH:
                self.reference_batch = None  # we can test for VBN with this invariant
                self.layers = nn.Sequential(*self.layers)
            elif reference_batch is None:
                raise ValueError('Normalization is virtual batch norm, but '
                                 '`reference_batch` is `None` or missing.')
            else:
                self.reference_batch = reference_batch  # never `None`
                self.layers = nn.ModuleList(self.layers)

        @staticmethod
        def build_layers(norm, activation, activation_kwargs, img_channels,
                         img_shape, features, g_input):
            """
            Return a list of the layers for the generator network.

            Example for a 64 x 64 image:
            >>> Generator.build_layers(Norm.BATCH, nn.ReLU, {'inplace': True},
                                       img_channels=3, img_shape=(64, 64), features=64, g_input=128)
            [
                # input size is 128 (given by `g_input`)
                nn.ConvTranspose2d(g_input, features * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(features * 8),
                nn.ReLU(True),
                # state size is (features * 8) x 4 x 4
                nn.ConvTranspose2d(features * 8, features * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(features * 4),
                nn.ReLU(True),
                # state size is (features * 4) x 8 x 8
                nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(features * 2),
                nn.ReLU(True),
                # state size is (features * 2) x 16 x 16
                nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
                nn.BatchNorm2d(features),
                nn.ReLU(True),
                # state size is (features) x 32 x 32
                nn.ConvTranspose2d(features, img_channels, 4, 2, 1, bias=False),
                nn.Tanh()
                # output size is 3 x 64 x 64 (given by `img_channels` and `img_shape`)
            ]
            """
            j = 2**(int(np.log2(img_shape[0])) - 3)
            # input size is (g_input)
            layers = [
                *layer_with_norm(
                    nn.ConvTranspose2d(
                        g_input, features * j, 4, 1, 0, bias=False), norm,
                    features * j),
                activation(**activation_kwargs)
            ]
            # state size is (features * 2^n) x 4 x 4
            # each further layer halves feature size and doubles image size
            while j > 1:
                i = j
                j //= 2
                layers.extend((*layer_with_norm(
                    nn.ConvTranspose2d(
                        features * i, features * j, 4, 2, 1, bias=False), norm,
                    features * j), activation(**activation_kwargs)))
            # state size is (features) x (img_shape[0] / 2) x (img_shape[1] / 2)
            layers.extend((nn.ConvTranspose2d(features,
                                              img_channels,
                                              4,
                                              2,
                                              1,
                                              bias=False), nn.Tanh()))
            # output size is (img_channels) x (img_shape[0]) x (img_shape[1])
            return layers

        @weak_script_method
        def forward(self, input):
            # Separation is for performance reasons
            if self.reference_batch is None:
                return self.layers(input)
            else:
                # VBN: push the reference batch through the same layers.
                ref_batch = self.reference_batch
                for layer in self.layers:
                    if not isinstance(layer, VirtualBatchNorm2d):
                        input = layer(input)
                        ref_batch = layer(ref_batch)
                    else:
                        input, ref_batch = layer(input, ref_batch)
                return input

    # Initialization

    def init_weights(module):
        """DCGAN-style init: conv weights ~ N(0, 0.02), BN weights ~ N(1, 0.02)."""
        if isinstance(module, ConvBase):
            nn.init.normal_(module.weight.data, 0.0, 0.02)
        elif isinstance(module, BatchNormBase):
            nn.init.normal_(module.weight.data, 1.0, 0.02)
            nn.init.constant_(module.bias.data, 0)

    g_net = Generator(g_params['normalization'], g_params['activation'],
                      g_params['activation_kwargs'], img_channels, img_shape,
                      g_params['features'], g_input,
                      example_noise.to(device,
                                       float_dtype)).to(device, float_dtype)

    # Load models' checkpoints (before wrapping, so the state dict keys match)

    if models_cp is not None:
        g_net.load_state_dict(models_cp['g_net_state_dict'])

    if multiple_gpus:
        g_net = nn.DataParallel(g_net, list(range(num_gpus)))

    if models_cp is None:
        g_net.apply(init_weights)

    # Load optimizers' checkpoints

    if models_cp is not None:
        if not load_weights_only:
            g_optim_state_dict = models_cp.get('g_optim_state_dict')
            if g_optim_state_dict is None:
                print(
                    "One of the optimizers' state dicts was not found; probably because "
                    "only the models' weights were saved. Set `load_weights_only` to `True`."
                )
            else:
                # Built as the `--optimizer`, `--lr`, `--optim_param` and
                # `--optim_kwargs` help texts describe:
                # `optimizer(parameters, lr, optim_param, **optim_kwargs)`.
                g_optimizer = g_params['optimizer'](
                    g_net.parameters(), g_params['lr'],
                    g_params['optim_param'], **g_params['optim_kwargs'])
                g_optimizer.load_state_dict(g_optim_state_dict)
            g_net.train()
        else:
            g_net.eval()

    def generate_fakes(batch_size, start_count, g_input, g_net, device,
                       float_dtype, zfill_len, save_dir):
        """Generate `batch_size` images from fresh noise.

        Each image is saved as `<save_dir>/<index>.png`, numbered from
        `start_count` and zero-padded to `zfill_len` digits.
        """
        noise = torch.randn(batch_size, g_input, 1, 1,
                            device=device).to(device, float_dtype)
        with torch.no_grad():
            fakes = g_net(noise).detach().cpu()

        for i, f in enumerate(fakes, start_count):
            tvutils.save_image(
                f, Path('{}/{}.png'.format(save_dir,
                                           str(i).zfill(zfill_len))))

    # Save images generated on noise.
    # `~` is expanded like for the load paths (see the path note at the top).
    save_dir = str(Path(save_dir).expanduser())
    zfill_len = len(str(num_imgs))

    num_full_gens, remainder = divmod(num_imgs, batch_size)
    for start in range(0, num_full_gens * batch_size, batch_size):
        generate_fakes(batch_size, start, g_input, g_net, device, float_dtype,
                       zfill_len, save_dir)
    # Avoid running the generator on an empty batch.
    if remainder:
        generate_fakes(remainder, num_full_gens * batch_size, g_input, g_net,
                       device, float_dtype, zfill_len, save_dir)
コード例 #12
0
from time import time
from os.path import isdir
from glob import glob
import sys
import cv2
import numpy as np
import torch
from torch import load, set_flush_denormal
from vision.ssd.fpnnet_ssd import create_fpnnet_ssd, create_fpn_ssd_predictor
from vision.utils.misc import Timer
from tracker import Tracker
from utils import Flags

# Flush denormal floats to zero on CPU to avoid slow denormal arithmetic.
set_flush_denormal(True)

# Expected invocation: <model path> <label path> [video file | image directory].
# (argv[1] is the model path; there is no separate <net> argument.)
if len(sys.argv) < 3:
    print(
        'Usage: python run_ssd_example.py <model path> <label path> '
        '[video file | image directory]',
        file=sys.stderr,
    )
    # A usage error is a failure: exit with a non-zero status.
    sys.exit(1)
model_path = sys.argv[1]
label_path = sys.argv[2]

# Input source state: a directory of image files, or (set below) a video capture.
is_dir = False
files = []
if len(sys.argv) >= 4:
    if isdir(sys.argv[3]):
        files = [f for f in glob('%s/*' % sys.argv[3])]
        is_dir = True
    else:
        cap = cv2.VideoCapture(sys.argv[3])  # capture from file
else:
コード例 #13
0
    def __init__(self,
                 weights_path,
                 key_index,
                 tok_to_id,
                 class_index,
                 testing_mode=True,
                 device="-1",
                 model_backbone='Attention_GCN_backbone',
                 number_adj=4,
                 num_classes=3):
        """Build the classification model and its lookup tables, then load weights.

        Args:
            weights_path: path of a state dict, loaded onto CPU with
                ``torch.load`` and then into ``self.model``.
            key_index: dict whose items are inverted into ``self.index_key``;
                in testing mode, a path to a pickle or JSON file holding it.
            tok_to_id: dict whose keys form ``self.all_character_in_dic``;
                in testing mode, a path to a pickle or JSON file holding it.
            class_index: dict whose items are inverted into
                ``self.index_class``; in testing mode, a path to a pickle or
                JSON file holding it.
            testing_mode: when True, the three lookup arguments are file paths
                to be loaded; otherwise they are used as given.
            device: comma-separated CUDA device ids (only the first one is used
                for ``self.device``) or ``"-1"`` to run on CPU.
            model_backbone: name of a backbone callable, resolved with ``eval``
                in this module's namespace.
            number_adj: forwarded to the backbone as its second positional
                argument.
            num_classes: forwarded to the backbone as ``num_classification``.

        Weight-loading failures are logged rather than raised, leaving the
        backbone with whatever weights it was constructed with.
        """
        self.model_backbone = model_backbone

        def _load_mapping(path):
            """Load a mapping from ``path``: try pickle first, fall back to JSON."""
            # SECURITY NOTE: pickle.load can execute arbitrary code; only point
            # this at trusted files.
            try:
                with open(path, 'rb') as file:
                    return pickle.load(file)
            except Exception:
                # Not a loadable pickle -- retry as a JSON text file.
                with open(path, 'r') as file:
                    return json.load(file)

        if testing_mode:
            assert isinstance(key_index, str) and isinstance(tok_to_id, str)
            self.key_index = _load_mapping(key_index)
            self.tok_to_id = _load_mapping(tok_to_id)
            self.class_index = _load_mapping(class_index)
        else:
            self.key_index = key_index
            self.tok_to_id = tok_to_id
            self.class_index = class_index

        # Reverse lookups: value -> key.
        self.index_key = {value: key for key, value in self.key_index.items()}
        self.index_class = {value: key for key, value in self.class_index.items()}

        self.all_character_in_dic = list(self.tok_to_id.keys())

        # NOTE(review): eval() executes arbitrary code -- model_backbone must
        # only ever come from trusted configuration.
        backbone = eval(model_backbone)
        print(backbone)
        self.model = backbone(
            len(self.all_character_in_dic) + 4, number_adj,
            len(self.key_index) + 1,
            self.key_index,
            num_classification=num_classes)

        if device != '-1' and torch.cuda.is_available():
            device_list = device.split(",")
            self.device = f"cuda:{device_list[0]}"
            self.device_number = device
        else:
            self.device = 'cpu'

        try:
            # torch.load also unpickles; only load trusted checkpoints.
            state_dict = torch.load(weights_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
        except Exception as err:
            logger.info(
                '[classification] Can not load the weight, please input the proper weight....')
            logger.info(f'[classification] weight loading error: {err!r}')
        else:
            logger.info('[classification] loading successfully....')
        # Move the model even if loading failed, so it always lives on
        # self.device and later inputs on that device do not mismatch.
        self.model.to(self.device)

        # Flush denormal floats to zero on CPU to avoid slow denormal arithmetic.
        torch.set_flush_denormal(True)

        # Presumably nn.Module.eval(): switch to inference mode -- confirm
        # against the class definition.
        self.eval()
コード例 #14
0
ファイル: main.py プロジェクト: zhuangzhong/RankIQA.PyTorch
def main_worker(gpu, args):
    """Run one worker: build the model, then JIT-convert, train and/or evaluate it.

    :param gpu: id of the GPU this worker runs on (stored as ``args.gpu``)
    :param args: parsed run-time hyper-parameters. Mutated in place:
        ``args.gpu`` is set, ``args.rank`` is rewritten in distributed mode,
        and ``args.evaluate`` is forced to True after training.
    """
    args.gpu = gpu
    started_at = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    utils.generate_logger(f"{started_at}-{gpu}.log")
    logging.info(f'args: {args}')

    # Reproducibility: seed the Python and torch RNGs, make cuDNN deterministic.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        logging.warning('You have chosen to seed training. '
                        'This will turn on the CUDNN deterministic setting, '
                        'which can slow down your training considerably! '
                        'You may see unexpected behavior when restarting '
                        'from checkpoints.')

    if not args.cuda:
        logging.info("Use CPU ~")
    else:
        logging.info(f"Use GPU: {args.gpu} ~")
        if args.distributed:
            # Turn args.rank (presumably the node index) into this process's
            # global rank before joining the NCCL process group.
            args.rank = args.rank * args.gpus + gpu
            dist.init_process_group(backend='nccl',
                                    init_method=args.init_method,
                                    world_size=args.world_size,
                                    rank=args.rank)

    # Build the network. Pretrained weights must be downloaded beforehand into
    # the `pretrained` folder, named after the network.
    logging.info(f"=> creating model '{args.arch}'")
    model = my_models.get_model(args.arch, args.pretrained,
                                num_classes=args.num_classes)

    # Optionally restore previously trained weights (loaded onto CPU first).
    if args.resume:
        if not os.path.isfile(args.resume):
            raise Exception(
                f"No checkpoint found at '{args.resume}' to be resumed")
        logging.info(f"=> loading checkpoint '{args.resume}'")
        checkpoint = torch.load(args.resume,
                                map_location=torch.device('cpu'))
        load_report = model.load_state_dict(checkpoint['state_dict'])
        logging.info(f'missing keys of models: {load_report.missing_keys}')
        del checkpoint

    # Log the model's input size and a summary of its layers.
    image_height, image_width = args.image_size
    logging.info(
        f'Model {args.arch} input size: ({image_height}, {image_width})')
    utils.summary(size=(image_height, image_width), channel=3, model=model)

    # TorchScript export mode: convert the resumed model and stop.
    if args.jit:
        if not args.resume:
            raise Exception('Option --resume must specified!')
        applications.convert_to_jit(model, args=args)
        return

    # Loss selection; builders are lazy so only the chosen loss is created.
    loss_builders = {
        'rank': lambda: criterions.RankingLoss(args=args),  # contrastive ranking loss
        'emd': lambda: criterions.EMDLoss(),  # earth mover's distance loss
        'regress': lambda: criterions.RegressionLoss(),  # MSE regression loss
    }
    build_loss = loss_builders.get(args.criterion)
    if build_loss is None:
        raise NotImplementedError(
            f'Not loss function {args.criterion},only (rank, emd, regress)!')
    criterion = build_loss()

    if args.cuda:
        if args.distributed and args.sync_bn:
            model = apex.parallel.convert_syncbn_model(model)
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    # Adam optimiser; alternatives such as SGD, RAdam, AdamW, Ranger or SWA
    # can be swapped in here.
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # GPU: amp mixed precision (opt level O1) wrapped in DDP; CPU: DataParallel.
    # NOTE(review): DDP is applied even when args.distributed is False, in
    # which case no process group was initialised above -- confirm that
    # single-GPU, non-distributed runs work.
    if args.cuda:
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        model = DDP(model)
    else:
        model = torch.nn.DataParallel(model)

    if args.train:
        train_loader = dataloader.load(args, 'train')
        val_loader = dataloader.load(args, 'val')
        scheduler = LambdaLR(
            optimizer, lambda ep: adjust_learning_rate(ep, args=args))
        applications.train(train_loader, val_loader, model, criterion,
                           optimizer, scheduler, args)
        # Always evaluate the freshly trained model.
        args.evaluate = True

    if not args.evaluate:
        return

    torch.set_flush_denormal(True)
    test_loader = dataloader.load(args, name='test')
    acc, loss, test_results = applications.test(test_loader, model,
                                                criterion, args)
    logging.info(f'Evaluation: * Acc@1 {acc:.3f} and loss {loss:.3f}.')
    logging.info('Evaluation results:')
    for row in test_results:
        logging.info(' '.join(map(str, row)))
    logging.info('Evaluation Over~')
コード例 #15
0
def main():
    """Fit and plot two dropout-MLP regressors on a 1-D toy dataset.

    Trains a regressor with a diagonal-Gaussian output density and a
    mixture-density regressor (Gaussian mixture with ``--n_components``
    components). Both use CDropout layers. The training data are noisy samples
    of ``f`` drawn from three disjoint input intervals. Each model is then
    evaluated on ``[-2.0, 1.5)`` and its predictive samples are plotted next to
    the true function and the training data.

    Options are parsed from the command line (see the parser below).
    """
    # model parameters
    parser = argparse.ArgumentParser("BNN regression example")
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--num_threads', type=int, default=1)
    parser.add_argument('--net_shape',
                        type=lambda s: [int(d) for d in s.split(',')],
                        default=[200, 200])
    parser.add_argument('--drop_rate', type=float, default=0.1)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--n_components', type=int, default=5)
    parser.add_argument('--N_batch', type=int, default=100)
    parser.add_argument('--train_iters', type=int, default=10000)
    parser.add_argument('--noise_level', type=float, default=1e-1)
    parser.add_argument('--resample', action='store_true')
    parser.add_argument('--use_cuda', action='store_true')
    args = parser.parse_args()

    # Seed both RNGs, limit intra-op threads, flush denormals on CPU.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.set_num_threads(args.num_threads)
    torch.set_flush_denormal(True)

    idims, odims = 1, 1
    # single gaussian output model (2 * odims outputs -- presumably mean and
    # log-scale per output dimension)
    mlp = models.mlp(idims,
                     2 * odims,
                     args.net_shape,
                     dropout_layers=[
                         models.CDropout(args.drop_rate * np.ones(hid))
                         for hid in args.net_shape
                     ])
    model = models.Regressor(mlp,
                             output_density=models.DiagGaussianDensity(odims))

    # mixture density network
    mlp2 = models.mlp(idims,
                      2 * args.n_components * odims + args.n_components + 1,
                      args.net_shape,
                      dropout_layers=[
                          models.CDropout(args.drop_rate * np.ones(hid))
                          for hid in args.net_shape
                      ])
    mmodel = models.Regressor(mlp2,
                              output_density=models.GaussianMixtureDensity(
                                  odims, args.n_components))

    # optimizer for single gaussian model
    opt1 = torch.optim.Adam(model.parameters(), args.lr)

    # optimizer for mixture density network
    opt2 = torch.optim.Adam(mmodel.parameters(), args.lr)

    # create training dataset: three disjoint input intervals, noisy targets
    train_x = np.concatenate([
        np.linspace(-1.6, -0.25, 100),
        np.linspace(0.1, 0.25, 100),
        np.linspace(0.65, 1.0, 100)
    ])
    train_y = f(train_x)
    train_y += args.noise_level * np.random.randn(*train_y.shape)
    X = torch.from_numpy(train_x[:, None]).float()
    Y = torch.from_numpy(train_y[:, None]).float()

    model.set_dataset(X, Y)
    mmodel.set_dataset(X, Y)

    model = model.float()
    mmodel = mmodel.float()

    if args.use_cuda and torch.cuda.is_available():
        X = X.cuda()
        Y = Y.cuda()
        model = model.cuda()
        mmodel = mmodel.cuda()

    # Print the values, not a tuple repr (2to3 artifact in the original).
    print('Dataset size:', train_x.shape[0], 'samples')
    # train unimodal regressor
    utils.train_regressor(model,
                          iters=args.train_iters,
                          batchsize=args.N_batch,
                          resample=args.resample,
                          optimizer=opt1,
                          log_likelihood=model.output_density.log_prob)

    # evaluate single gaussian model
    test_x = np.arange(-2.0, 1.5, 0.005)
    ret = []
    if args.resample:
        model.resample()
    for x in test_x:
        xt = torch.tensor(x[None]).float().to(model.X.device)
        # Evaluate 2 * N_batch copies of the query point in one batch.
        outs = model(xt.expand((2 * args.N_batch, 1)), resample=False)
        y = torch.cat(outs[:2], -1)
        ret.append(y.cpu().detach().numpy())
        torch.cuda.empty_cache()
    ret = np.stack(ret)
    ret = ret.transpose(1, 0, 2)
    torch.cuda.empty_cache()
    for _ in range(3):
        gc.collect()

    plt.figure(figsize=(16, 9))
    nc = ret.shape[-2]
    colors = np.array(list(plt.cm.rainbow_r(np.linspace(0, 1, nc))))
    for i in range(len(ret)):
        m, logS = ret[i, :, 0], ret[i, :, 1]
        samples = gaussian_sample(m, logS)
        plt.scatter(test_x, m, c=colors[0:1], s=1)
        plt.scatter(test_x, samples, c=colors[0:1] * 0.5, s=1)
    plt.plot(test_x, f(test_x), linestyle='--', label='true function')
    plt.scatter(X.cpu().numpy().flatten(), Y.cpu().numpy().flatten())
    plt.xlabel('$x$', fontsize=18)
    plt.ylabel('$y$', fontsize=18)
    plt.legend()  # otherwise the 'true function' label is never shown

    print(model)

    # train mixture regressor
    utils.train_regressor(mmodel,
                          iters=args.train_iters,
                          batchsize=args.N_batch,
                          resample=args.resample,
                          optimizer=opt2,
                          log_likelihood=mmodel.output_density.log_prob)

    # evaluate mixture density network
    test_x = np.arange(-2.0, 1.5, 0.005)
    ret = []
    logit_weights = []
    if args.resample:
        mmodel.resample()
    for x in test_x:
        xt = torch.tensor(x[None]).float().to(mmodel.X.device)
        outs = mmodel(xt.expand((2 * args.N_batch, 1)), resample=False)
        y = torch.cat(outs[:2], -2)
        ret.append(y.cpu().detach().numpy())
        logit_weights.append(outs[2].cpu().detach().numpy())
        torch.cuda.empty_cache()
    ret = np.stack(ret)
    ret = ret.transpose(1, 0, 2, 3)
    logit_weights = np.stack(logit_weights)
    logit_weights = logit_weights.transpose(1, 0, 2)
    torch.cuda.empty_cache()
    for _ in range(3):
        gc.collect()

    plt.figure(figsize=(16, 9))
    nc = ret.shape[-1]
    colors = np.array(list(plt.cm.rainbow_r(np.linspace(0, 1, nc))))
    for i in range(len(ret)):
        m, logS = ret[i, :, 0, :], ret[i, :, 1, :]
        # Noisy samples (faded colours), then noise-free samples on top.
        samples, c = mixture_sample(m, logS, logit_weights[i], colors)
        plt.scatter(test_x, samples, c=c * 0.5, s=1)
        samples, c = mixture_sample(m,
                                    logS,
                                    logit_weights[i],
                                    colors,
                                    noise=False)
        plt.scatter(test_x, samples, c=c, s=1)
    plt.plot(test_x, f(test_x), linestyle='--', label='true function')
    plt.scatter(X.cpu().numpy().flatten(), Y.cpu().numpy().flatten())
    plt.xlabel('$x$', fontsize=18)
    plt.ylabel('$y$', fontsize=18)
    plt.legend()  # otherwise the 'true function' label is never shown

    print(mmodel)

    plt.show()