Code example #1
def main(config):
    """
    Construct meta-transform class, and add it as a new attribute to config under dataset
    """
    list_transform = [
        get_instance(module_transform, i, config) for i in config
        if 'transform' in i
    ]
    transform = transforms.Compose(
        list_transform)  # construct meta-transform class
    print(transform)
    config['dataset']['args'][
        'transform'] = transform  # add new attributes to dataset args
    """
    Construct dataset along with the added meta-transform class
    """
    d = get_instance(
        module_dataset, 'dataset',
        config)  # construct dataset with added meta-transform class

    save_path = os.path.join(config['save_dir'], config['save_subdir'])
    ensure_dir(save_path)
    print("The processed audio will be saved at %s", save_path)
    config['dataset']['args'].pop(
        'transform',
        None)  # remove the added attributes since it prevents saving to json
    save_json(config, os.path.join(save_path, 'config_audio.json'))
    """
    Open a figure to draw 9 randomly sampled extracted spectrogram
    """
    np.random.seed(1234)
    display_samples = np.random.choice(len(d), size=9, replace=False)
    gs = gridspec.GridSpec(3, 3)
    fig = plt.figure(figsize=(15, 15))
    n_fig = 0

    start_time = time.time()
    for k in range(len(d)):
        print("Transforming %d-th data ... %s" % (k, d.path_to_data[k]))
        x, idx, fp = d[k]
        assert idx == k
        audio_id = fp.split('/')[-1].split('.')[0]
        p = os.path.join(save_path,
                         '%s-%s.%s' % (audio_id, config['save_subdir'], 'pth'))
        # np.save(p, x)
        torch.save(x, p)

        if k in display_samples:
            ax = fig.add_subplot(gs[n_fig])
            if len(x.size()) > 2:
                x = x.squeeze(0)
            ax.imshow(x, aspect='auto', origin='lower')
            # ax.set_yticklabels([])
            # ax.set_xticklabels([])
            ax.set_title(audio_id)
            n_fig += 1
            plt.savefig(os.path.join(save_path, '.'.join(['spec', 'jpg'])))

    plt.savefig(os.path.join(save_path, '.'.join(['spec', 'jpg'])))
    print("Time: %.2f seconds" % (time.time() - start_time))
Code example #2
    def __init__(self, model, config):
        super(OurOptimizersAndSchedulers, self).__init__()

        self.encoder_optimizer = get_instance(
            optim, 'optimizer_encoder', config,
            get_trainable_params(model.encoder))
        self.decoder_optimizer = get_instance(
            optim, 'optimizer_decoder', config,
            get_trainable_params(model.decoder))
        self.code_generator_optimizer = get_instance(
            optim, 'optimizer_code_generator', config,
            get_trainable_params(model.code_generator))
        self.d_i_optimizer = get_instance(optim, 'optimizer_d_i', config,
                                          get_trainable_params(model.d_i))
        self.d_c_optimizer = get_instance(optim, 'optimizer_d_c', config,
                                          get_trainable_params(model.d_c))

        self.encoder_scheduler = get_instance(optim.lr_scheduler,
                                              'lr_scheduler_encoder', config,
                                              self.encoder_optimizer)
        self.decoder_scheduler = get_instance(optim.lr_scheduler,
                                              'lr_scheduler_decoder', config,
                                              self.decoder_optimizer)
        self.code_generator_scheduler = get_instance(
            optim.lr_scheduler, 'lr_scheduler_code_generator', config,
            self.code_generator_optimizer)
        self.d_i_scheduler = get_instance(optim.lr_scheduler,
                                          'lr_scheduler_d_i', config,
                                          self.d_i_optimizer)
        self.d_c_scheduler = get_instance(optim.lr_scheduler,
                                          'lr_scheduler_d_c', config,
                                          self.d_c_optimizer)
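get_trainable_params is not shown in this excerpt; in this template style it is usually just a requires_grad filter, mirroring the filter(lambda p: p.requires_grad, ...) calls used elsewhere in these examples. A plausible sketch, not the project's actual helper:

def get_trainable_params(submodule):
    # keep only parameters that will actually receive gradients
    return filter(lambda p: p.requires_grad, submodule.parameters())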
Code example #3
def main(config, resume):
    torch.random.manual_seed(1986)
    torch.cuda.empty_cache()
    train_logger = Logger()

    # setup data_loader instances
    data_loader = get_instance(module_data, 'data_loader', config)
    valid_data_loader = data_loader.split_validation()

    # build model architecture
    model = get_instance(module_arch, 'arch', config)
    model.summary()

    # get function handles of loss and metrics
    loss = module_loss.OurLosses()
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler

    optimizers_and_schedulers = OurOptimizersAndSchedulers(model, config)

    trainer = Trainer(model,
                      loss,
                      metrics,
                      optimizers_and_schedulers,
                      resume=resume,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      train_logger=train_logger)

    trainer.train()
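For a main like the one above, the JSON config consumed by get_instance would need matching 'data_loader', 'arch', 'metrics' and per-component optimizer/scheduler entries (see OurOptimizersAndSchedulers in code example #2). The fragment below is purely illustrative; the type names are placeholders, not taken from the project.

config = {
    "data_loader": {"type": "OurDataLoader", "args": {"data_dir": "data/", "batch_size": 32, "validation_split": 0.1}},
    "arch": {"type": "OurModel", "args": {}},
    "metrics": ["accuracy"],
    # one optimizer_* / lr_scheduler_* entry per sub-network, as read by OurOptimizersAndSchedulers
    "optimizer_encoder": {"type": "Adam", "args": {"lr": 1e-4}},
    "lr_scheduler_encoder": {"type": "StepLR", "args": {"step_size": 50, "gamma": 0.5}},
}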
Code example #4
File: classifier.py, Project: zzll136/ModelFeast
    def __init__(self, model='xception', n_classes=10, img_size=(224, 224), data_dir = None,
                    pretrained=False, pretrained_path="./pretrained/", n_gpu=1,
                    default_init=True):
        """ 初始化时只初始化model model可以是string 也可以是自己创建的model """
        super(classifier, self).__init__()

        self.resume = None
        self.data_loader = None
        self.valid_data_loader = None
        self.train_logger = Logger()

        if isinstance(model, str):
            arch = {
                    "type": model, 
                    "args": {"n_class": n_classes, "img_size": img_size, 
                    "pretrained": pretrained, "pretrained_path": pretrained_path} 
                    }
            self.config = {"name": model, "arch": arch, "n_gpu":n_gpu}
            self.model = get_instance(model_zoo, 'arch', self.config)       
            # self.model = getattr(model_zoo, model)(n_classes, img_size, pretrained, pretrained_path)
        elif callable(model):
            model_name = model.__class__.__name__
            arch = {
                    "type": model_name, 
                    "args": {"n_class": n_classes, "img_size": img_size} 
                    }
            self.config = {"name": model_name, "arch": arch, "n_gpu":n_gpu}
            self.model = model         
        else:
            self.logger.info("input type is invalid, please set model as str or a callable object")
            raise Exception("model: wrong input error")

        if default_init:
            # self.loss = torch.nn.CrossEntropyLoss()  # same effect
            self.config["loss"] = "cls_loss"
            self.loss = getattr(module_loss, self.config["loss"])
            
            self.config["metrics"] = ["accuracy"]
            self.metrics = [getattr(module_metric, met) for met in self.config['metrics']]

            # build optimizer
            self.config["optimizer"] = {"type": "Adam", "args":{"lr": 0.0003, "weight_decay": 0.00003}}
            optimizer_params = filter(lambda p: p.requires_grad, self.model.parameters())
            self.optimizer = get_instance(torch.optim, 'optimizer', self.config, optimizer_params)

            self.config["lr_scheduler"] = {"type": "StepLR", "args": {"step_size": 50, "gamma": 0.2 }}
            self.lr_scheduler = get_instance(torch.optim.lr_scheduler, 'lr_scheduler', 
                self.config, self.optimizer)

            self.set_trainer()

        if data_dir:
            self.autoset_dataloader(data_dir, batch_size=64, shuffle=True, validation_split=0.2, 
                num_workers=4, transform = None)
Code example #5
File: classifier.py, Project: zzll136/ModelFeast
    def init_from_config(config_file, resume=None, user_defined_model=None):
        '''
        if user_defined_model is callable, init classifier with the given model
        otherwise, use model in modelfeast
        '''
        
        config = json.load(open(config_file))
        model_config = config['arch']['args']

        if callable(user_defined_model):
            kernal_model = user_defined_model
        else:
            kernal_model = config['arch']['type']

        clf = classifier(model=kernal_model, n_classes=model_config['n_class'],
                        img_size=model_config['img_size'], data_dir = None, 
                        pretrained=model_config['pretrained'], 
                        pretrained_path=model_config['pretrained_path'], 
                        default_init=False)

        loss = getattr(module_loss, config['loss'])
        metrics = [getattr(module_metric, met) for met in config['metrics']]
        optimizer_params = filter(lambda p: p.requires_grad, clf.model.parameters())
        optimizer = get_instance(torch.optim, 'optimizer', config, optimizer_params)
        lr_scheduler = get_instance(torch.optim.lr_scheduler, 'lr_scheduler', config, optimizer)

        # setup data_loader config
        if 'args' in config['data_loader']:
            data_cng = config['data_loader']['args']
            transform = None if 'transform' not in data_cng else data_cng['transform']
            if not transform and config["arch"]["args"]["img_size"]:
                config['data_loader']['args']['transform'] = config["arch"]["args"]["img_size"]

        try:
            data_loader = get_instance(module_data, 'data_loader', config)
            valid_data_loader = data_loader.split_validation()
        except AttributeError:
            clf.logger.warning('Can not match data loader in config file!!! please set data loader manually!!!')
            data_loader = None
            valid_data_loader = None
            pass

        # set classifier according to those
        clf.config = config
        clf.loss = loss
        clf.resume = resume
        clf.metrics = metrics
        clf.optimizer = optimizer
        clf.lr_scheduler = lr_scheduler
        clf.data_loader = data_loader
        clf.valid_data_loader = valid_data_loader
        
        return clf
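Putting code examples #4 and #5 together, a hypothetical usage of the ModelFeast classifier could look like the sketch below; the paths and the assumption that init_from_config is a staticmethod (the decorator lies outside the excerpt) are mine, and training itself is not shown in these snippets.

# rebuild a classifier from a saved config and checkpoint (placeholder paths)
clf = classifier.init_from_config('saved/config.json', resume='saved/model_best.pth')

# or build one directly from a model name, then attach data as in code example #4
clf = classifier(model='xception', n_classes=10, img_size=(224, 224))
clf.autoset_dataloader('data/train', batch_size=64, shuffle=True,
                       validation_split=0.2, num_workers=4, transform=None)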
Code example #6
    def __init__(self, **kwargs):
        super().__init__()

        config = Config(kwargs)
        self.minor_encoder = get_instance(config.minor_encoder)

        if hasattr(config, "major_encoder"):
            self.major_encoder = get_instance(config.major_encoder)
        elif hasattr(config, 'encoder_config_share'):
            self.major_encoder = get_instance(config.minor_encoder)
        else:
            self.major_encoder = None

        self.config = config
Code example #7
File: test.py, Project: steve-smashnuk/Songify
def main(config, resume):
    # setup collate function
    collate_fn = getattr(module_collate, config['test_collate_fn'])

    # setup data_loader instances
    data_loader = get_instance(module_data_loader, 'test_data_loader', config, collate_fn=collate_fn)

    # build model architecture
    model = get_instance(module_model, 'model', config)
    model.summary()

    # get function handles of loss and metrics
    loss_fn = getattr(module_loss, config['loss'])
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    # load state dict
    checkpoint = torch.load(resume)
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # prepare model for testing
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fns))

    with torch.no_grad():
        for i, (data, target) in enumerate(tqdm(data_loader)):
            data, target = data.to(device), target.to(device)
            output = model(data)
            #
            # save sample images, or do something with output here
            #

            # computing loss, metrics on test set
            loss = loss_fn(output, target)
            batch_size = data.shape[0]
            total_loss += loss.item() * batch_size
            for j, metric in enumerate(metric_fns):  # separate index so the batch counter i is not shadowed
                total_metrics[j] += metric(output, target) * batch_size

    n_samples = len(data_loader.sampler)
    log = {'loss': total_loss / n_samples}
    log.update({met.__name__ : total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)})
    print(log)
Code example #8
File: tasks.py, Project: morelab/ckanext-extractor
def perform_extraction(package_id, mainclass):
    full_task_id = perform_extraction.name + '_' + package_id
    
    if not is_running(full_task_id):
        add_task(full_task_id)
        
        t = session.query(Transformation).filter_by(package_id=package_id).first()

        #change to transformation directory
        os.chdir(t.output_dir)

        #create context and call transformation entry point
        context = ExtractionContext(t, session)

        log.info('Starting transformation %s' % package_id)
        transformation_instance = get_instance(t.output_dir, mainclass)

        try: 
            transformation_instance.start_transformation(context)
        except:
            comment = traceback.format_exc()
            context.finish_error(comment)
            log.info(comment)
            
        remove_task(full_task_id)
    else:
        print('Extraction task is already running %s' % full_task_id)
Code example #9
File: domain_name.py, Project: scalefront/dyndns53
def generate_domain_name(product_tld, hosted_zone, subdomain=None, end_with_period=False):
    from utils import get_current_instance_id, get_instance, replace_parent_domain
    instance_id = get_current_instance_id()
    instance = get_instance(instance_id)
    tag_name = instance.tags['Name']
    domain_name = replace_parent_domain(tag_name, product_tld, hosted_zone, end_with_period)
    if subdomain:
        from utils import join_domain
        domain_name = join_domain(subdomain, domain_name)
    return domain_name
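A hypothetical call might look like the following; the TLD and hosted-zone values are placeholders, since the helper just resolves the current EC2 instance's Name tag into a Route 53 record name.

# placeholder arguments; real values come from the deployment's Route 53 setup
record = generate_domain_name('example.com', hosted_zone='Z123EXAMPLE',
                              subdomain='api', end_with_period=True)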
Code example #10
def load_model_for_eval(checkpoint):
    config_file = Path(checkpoint).parent / 'config.json'
    config = read_json(config_file)
    model = get_instance(module_arch, 'arch', config)
    model.summary()
    checkpoint = torch.load(checkpoint, map_location='cpu')
    state_dict = checkpoint['state_dict']
    model.load_state_dict(clean_state_dict(state_dict))
    model.eval()
    return model
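clean_state_dict is not part of the excerpt; a common reason for such a helper is stripping the 'module.' prefix that torch.nn.DataParallel adds to parameter names. A sketch under that assumption:

from collections import OrderedDict

def clean_state_dict(state_dict):
    # drop the 'module.' prefix left over from nn.DataParallel checkpoints
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        cleaned[key[len('module.'):] if key.startswith('module.') else key] = value
    return cleaned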
Code example #11
def _retrieve_property(_ctx, property_name):
    property_from_client_config = get_node(_ctx).properties \
        .get('client_config', {}).get(property_name, {})
    target = _retrieve_master(get_instance(_ctx))

    if target:
        _ctx.logger.info("using property from managed_by_master"
                         " relationship for node: {0}, it will be deprecated"
                         " soon please use client_config property!".format(
                             _ctx.node.name))
        configuration = target.node.properties.get(property_name, {})
        configuration.update(
            target.instance.runtime_properties.get(property_name, {}))

    else:
        configuration = property_from_client_config
        configuration.update(
            get_instance(_ctx).runtime_properties.get(property_name, {}))

    return configuration
Code example #12
File: baseline.py, Project: shyaoni/ctt
    def __init__(self, field, config):
        super().__init__()

        embedding = field.get_embedding_from_glove(config.embedding)
        self.embedding = nn.Embedding(
            num_embeddings=config.embedding.vocab_size + 4,
            embedding_dim=config.embedding.dim,
            padding_idx=0,
            _weight=embedding)

        self.context_encoder = get_instance(config.context_encoder)
        self.candidates_encoder = get_instance(config.candidates_encoder)

        transform_mat = torch.Tensor(truncnorm.rvs(-2, 2, size=(
            config.candidates_encoder.code_size,
            config.context_encoder.code_size)))
        transform_bias = torch.zeros(1)

        self.transform_mat = torch.nn.Parameter(transform_mat)
        self.transform_bias = torch.nn.Parameter(transform_bias)
        self.config = config
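The forward pass is not included in this excerpt. Given that transform_mat has shape (candidates_code_size, context_code_size), one plausible, purely speculative use is a bilinear score between candidate codes and the context code, roughly as sketched here with made-up shapes.

import torch

# hypothetical shapes mirroring the code_size fields in the config above
batch, n_cand, cand_size, ctx_size = 2, 5, 64, 64
candidates_code = torch.randn(batch, n_cand, cand_size)
context_code = torch.randn(batch, ctx_size)
transform_mat = torch.randn(cand_size, ctx_size)
transform_bias = torch.zeros(1)

# bilinear score of every candidate against its context (a guess, not the project's actual forward())
scores = torch.einsum('bnc,cd,bd->bn', candidates_code, transform_mat, context_code) + transform_bias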
Code example #13
def main(config, resume):
    train_logger = Logger()

    # setup collate function
    collate_fn = getattr(module_collate, config['collate_fn'])

    # setup data_loader instances
    data_loader = get_instance(module_data_loader,
                               'data_loader',
                               config,
                               collate_fn=collate_fn)
    valid_data_loader = data_loader.split_validation()

    # build model architecture
    model = get_instance(module_model, 'model', config)
    print(model)

    # get function handles of loss and metrics
    loss = getattr(module_loss, config['loss'])
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = get_instance(torch.optim, 'optimizer', config,
                             trainable_params)
    lr_scheduler = get_instance(torch.optim.lr_scheduler, 'lr_scheduler',
                                config, optimizer)

    trainer = Trainer(model,
                      loss,
                      metrics,
                      optimizer,
                      resume=resume,
                      config=config,
                      data_loader=data_loader,
                      valid_data_loader=valid_data_loader,
                      lr_scheduler=lr_scheduler,
                      train_logger=train_logger)

    trainer.train()
Code example #14
def main(config, resume):
    torch.manual_seed(1234)
    train_logger = Logger()

    # setup data_loader instances
    data_loader = get_instance(module_data, 'data_loader', config)
    valid_data_loader = data_loader.split_validation()
    dd = np.setdiff1d(valid_data_loader.sampler.indices,
                      np.load('data/valid_idx.npy'))
    assert len(dd) == 0

    # build model architecture
    model = get_instance(module_arch, 'arch', config)
    model.apply(weights_init)
    print(model)

    # get function handles of loss and metrics
    loss = {
        l_i: get_instance(module_loss, l_i, config)
        for l_i in config if 'loss' in l_i
    }
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer and learning rate scheduler; delete every line containing lr_scheduler to disable the scheduler
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = get_instance(torch.optim, 'optimizer', config,
                             trainable_params)

    trainer = GMVAE_Trainer(model,
                            loss,
                            metrics,
                            optimizer,
                            resume=resume,
                            config=config,
                            data_loader=data_loader,
                            label_portion=config['trainer']['label_portion'],
                            valid_data_loader=valid_data_loader,
                            train_logger=train_logger)

    trainer.train()
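Here the loss is a dict built from every config key containing the substring 'loss', each entry instantiated from module_loss via get_instance. A hypothetical config fragment that would satisfy that loop (key and class names invented for illustration):

config = {
    # every key containing 'loss' becomes one entry of the loss dict
    "loss_reconstruction": {"type": "ReconstructionLoss", "args": {"reduction": "mean"}},
    "loss_kl": {"type": "KLDivergenceLoss", "args": {"weight": 0.1}},
    # plus the usual 'arch', 'data_loader', 'optimizer', 'metrics', 'trainer' entries
}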
Code example #15
    def deploy_transformation(self, transformation):
        mainclass, required = get_config_data(transformation.output_dir)
        transformation_instance = get_instance(transformation.output_dir, mainclass)
        transformation_instance.deploy()

        #install depedencies using celery
        celery.send_task("extractor.install_dependencies",
            args=[required], task_id=str(uuid.uuid4()))

        #remove extraction log
        transformation.extractions = []
        model.Session.merge(transformation)
        model.Session.commit()
Code example #16
File: api.py, Project: shyaoni/ctt
def chatbot(config_path):
    config = importlib.import_module(config_path).config
    dataset = get_dataset('dts_ConvAI2')
    vocab = dataset.get_vocab(config.data.vocab_size)
    corpus_set = dataset.get_set(config.data.corpus_set, vocab)
    predictor = utils.get_instance(config.model)
    predictor.load(config.save_path)
    chatbot = Chatbot(
        InferGraphAgent(predictor,
                        id='agent').corpus(corpus_set,
                                           config.data.corpus_set).build())
    chatbot.send_target('work')

    return chatbot
Code example #17
import torch

import utils
from data import data_parser, get_dataset

import module.tokenizer
module.tokenizer.tokenize_pipeline()

torch.multiprocessing.set_sharing_strategy('file_system')

from module.chat import Sess, Chatbot, InferGraphAgent

if __name__ == '__main__':
    config, args = utils.get_config(data_parser())
    dataset = get_dataset(args.dts)

    vocab = dataset.get_vocab(config.data.vocab_size)
    corpus_set = dataset.get_set(config.data.corpus_set, vocab)

    predictor = utils.get_instance(config.model)
    predictor.load(config.save_path)

    chatbot = Chatbot(
        InferGraphAgent(predictor,
                        id='agent').corpus(corpus_set,
                                           config.data.corpus_set).build())

    from IPython import embed
    embed()
Code example #18
def main(config):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for model_type in "SmallNet", "HourglassNet":
        for imwidth in 70, 96, 128:
            for keypoint_reg in False, True:
                for upsample in False, True:

                    if upsample and not keypoint_reg:
                        continue  # not needed
                    if model_type == "HourglassNet" and imwidth not in {96, 128}:
                        continue  # not needed
                    if model_type == "SmallNet" and imwidth not in {70, 128}:
                        continue  # not needed

                    profile_name = get_profile_name(
                        model_type=model_type,
                        keypoint_reg=keypoint_reg,
                        upsample=upsample,
                        imwidth=imwidth,
                    )
                    config["dataset"]["args"] = {"imwidth": imwidth}
                    val_dataset = get_instance(
                        module=module_data,
                        name='dataset',
                        config=config,
                        train=False,
                    )
                    val_loader = DataLoader(
                        val_dataset,
                        batch_size=config["batch_size"],
                    )
                    config["arch"] = {
                        "type": model_type,
                        "args": {"num_output_channels": 64},
                    }

                    model = get_instance(module_arch, 'arch', config)

                    if keypoint_reg:
                        descdim = config['arch']['args']['num_output_channels']
                        kp_regressor = get_instance(module_arch, 'keypoint_regressor',
                                                    config,
                                                    descriptor_dimension=descdim)
                        basemodel = NoGradWrapper(model)

                        if upsample:
                            model = nn.Sequential(basemodel, Up(), kp_regressor)
                        else:
                            model = nn.Sequential(basemodel, kp_regressor)
                    # model.summary()

                    # prepare model for testing
                    model = model.to(device)
                    model.eval()
                    timings = []
                    warmup = 3
                    num_batches = 10

                    with torch.no_grad():
                        # count = 0
                        tic = time.time()
                        for ii, batch in enumerate(val_loader):
                            data = batch["data"].to(device)
                            batch_size = data.shape[0]
                            _ = model(data)
                            speed = batch_size / (time.time() - tic)
                            if ii > warmup:
                                timings.append(speed)
                            if ii > warmup + num_batches:
                                break
                            # print("speed: {:.3f}Hz".format(speed))
                            tic = time.time()
                            # count += batch_size

                    flops, params = profile(
                        model,
                        input_size=(1, 3, imwidth, imwidth),
                        verbose=False,
                    )
                    # use format so that its easy to latexify
                    template = "{} & {:.1f} & {:.1f} & ${:.1f} (\\pm {:.1f})$\\\\"
                    template = template.format(
                        profile_name,
                        params / 10**6,
                        flops / 10**9,
                        np.mean(timings),
                        np.std(timings),
                    )
                    print(template)
Code example #19
File: train.py, Project: abutaufique/DVE
def main(config, resume):
    logger = config.get_logger('train')
    seeds = [int(x) for x in config._args.seeds.split(",")]
    torch.backends.cudnn.benchmark = True
    logger.info("Launching experiment with config:")
    logger.info(config)

    if len(seeds) > 1:
        run_metrics = []

    for seed in seeds:
        tic = time.time()
        logger.info(f"Setting experiment random seed to {seed}")
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

        model = get_instance(module_arch, 'arch', config)
        logger.info(model)

        if 'finetune_from' in config.keys():
            checkpoint = torch.load(config['finetune_from'])
            model.load_state_dict(clean_state_dict(checkpoint["state_dict"]))
            print('Finetuning from %s' % config['finetune_from'])

        if 'keypoint_regressor' in config.keys():
            descdim = config['arch']['args']['num_output_channels']
            kp_regressor = get_instance(module_arch,
                                        'keypoint_regressor',
                                        config,
                                        descriptor_dimension=descdim)
            basemodel = NoGradWrapper(model)

            if config.get('keypoint_regressor_upsample', False):
                model = nn.Sequential(basemodel, Up(), kp_regressor)
            else:
                model = nn.Sequential(basemodel, kp_regressor)

        if 'segmentation_head' in config.keys():
            descdim = config['arch']['args']['num_output_channels']
            segmenter = get_instance(module_arch,
                                     'segmentation_head',
                                     config,
                                     descriptor_dimension=descdim)
            if config["segmentation_head"]["args"].get("freeze_base", True):
                basemodel = NoGradWrapper(model)
            else:
                basemodel = model

            if config.get('segmentation_upsample', False):
                model = nn.Sequential(basemodel, Up(), segmenter)
            else:
                model = nn.Sequential(basemodel, segmenter)

        # setup data_loader instances
        imwidth = config['dataset']['args']['imwidth']
        warper = get_instance(tps, 'warper', config, imwidth,
                              imwidth) if 'warper' in config.keys() else None

        loader_kwargs = {}
        coll_func = config.get("collate_fn", "dict_flatten")
        if coll_func == "flatten":
            loader_kwargs["collate_fn"] = coll
        elif coll_func == "dict_flatten":
            loader_kwargs["collate_fn"] = dict_coll
        else:
            raise ValueError(
                "collate function type {} unrecognised".format(coll_func))

        dataset = get_instance(module_data,
                               'dataset',
                               config,
                               pair_warper=warper,
                               train=True)
        if config["disable_workers"]:
            num_workers = 0
        else:
            num_workers = 4

        if config.get("restrict_annos", False):
            dataset.restrict_annos(num=config["restrict_annos"])
            logger.info(
                f"restricting annotation to {config['restrict_annos']} samples..."
            )

        data_loader = DataLoader(
            dataset,
            batch_size=int(config["batch_size"]),
            num_workers=num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            **loader_kwargs,
        )

        warp_val = config.get('warp_val', True)
        val_dataset = get_instance(
            module_data,
            'dataset',
            config,
            train=False,
            pair_warper=warper if warp_val else None,
        )
        valid_data_loader = DataLoader(val_dataset,
                                       batch_size=32,
                                       **loader_kwargs)

        # get function handles of loss and metrics
        loss = getattr(module_loss, config['loss'])
        metrics = [getattr(module_metric, met) for met in config['metrics']]
        if not config["vis"]:
            visualizations = []
        else:
            visualizations = [
                getattr(module_visualization, vis)
                for vis in config['visualizations']
            ]

        # build optimizer and learning rate scheduler; delete every line
        # containing lr_scheduler to disable the scheduler
        trainable_params = list(
            filter(lambda p: p.requires_grad, model.parameters()))

        if 'keypoint_regressor' in config.keys():
            base_params = list(
                filter(lambda p: p.requires_grad, basemodel.parameters()))
            trainable_params = [
                x for x in trainable_params
                if not sum([(x is w) for w in base_params])
            ]

        biases = [x.bias for x in model.modules() if isinstance(x, nn.Conv2d)]

        trainbiases = [
            x for x in trainable_params if sum([(x is b) for b in biases])
        ]
        trainweights = [
            x for x in trainable_params if not sum([(x is b) for b in biases])
        ]
        print(len(trainbiases), 'Biases', len(trainweights), 'Weights')

        bias_lr = config.get('bias_lr', None)
        if bias_lr is not None:
            optimizer = get_instance(torch.optim, 'optimizer', config,
                                     [{
                                         "params": trainweights
                                     }, {
                                         "params": trainbiases,
                                         "lr": bias_lr
                                     }])
        else:
            optimizer = get_instance(torch.optim, 'optimizer', config,
                                     trainable_params)

        lr_scheduler = get_instance(torch.optim.lr_scheduler, 'lr_scheduler',
                                    config, optimizer)
        trainer = Trainer(
            model=model,
            loss=loss,
            metrics=metrics,
            resume=resume,
            config=config,
            optimizer=optimizer,
            data_loader=data_loader,
            lr_scheduler=lr_scheduler,
            visualizations=visualizations,
            mini_train=config._args.mini_train,
            valid_data_loader=valid_data_loader,
        )
        trainer.train()
        duration = time.strftime('%Hh%Mm%Ss', time.gmtime(time.time() - tic))
        logger.info(f"Training took {duration}")
        if "keypoint_regressor" not in config.keys():
            epoch = config["trainer"]["epochs"]
            config._args.resume = config.save_dir / f"checkpoint-epoch{epoch}.pth"
            config["mini_eval"] = config._args.mini_train
            evaluation(config, logger=logger)
            logger.info(f"Log written to {config.log_path}")
        elif "keypoint_regressor" in config.keys() and len(seeds) > 1:
            run_metrics.append(copy.deepcopy(trainer.latest_log))

    if len(seeds) > 1 and "keypoint_regressor" in config.keys():
        target = "val_inter_ocular_error"
        errors = [x[target] for x in run_metrics]
        logger.info(
            f"{target} -> mean: {np.mean(errors)}, std: {np.std(errors)}")
Code example #20
def evaluation(config, logger=None, eval_data=None):
    device = torch.device('cuda:0' if config["n_gpu"] > 0 else 'cpu')

    if logger is None:
        logger = config.get_logger('test')

    logger.info("Running evaluation with configuration:")
    logger.info(config)

    imwidth = config['dataset']['args']['imwidth']
    root = config["dataset"]["args"]["root"]
    warp_crop_default = config['warper']['args'].get('crop', None)
    crop = config['dataset']['args'].get('crop', warp_crop_default)

    # Want explicit pair warper
    disable_warps = True
    dense_match = config.get("dense_match", False)
    if dense_match and disable_warps:
        # rotsd = 2.5
        # scalesd=0.1 * .5
        rotsd = 0
        scalesd = 0
        warp_kwargs = dict(warpsd_all=0,
                           warpsd_subset=0,
                           transsd=0,
                           scalesd=scalesd,
                           rotsd=rotsd,
                           im1_multiplier=1,
                           im1_multiplier_aff=1)
    else:
        warp_kwargs = dict(warpsd_all=0.001 * .5,
                           warpsd_subset=0.01 * .5,
                           transsd=0.1 * .5,
                           scalesd=0.1 * .5,
                           rotsd=5 * .5,
                           im1_multiplier=1,
                           im1_multiplier_aff=1)
    warper = tps.Warper(imwidth, imwidth, **warp_kwargs)
    if eval_data is None:
        eval_data = config["dataset"]["type"]
    constructor = getattr(module_data, eval_data)

    # handle the case of the MAFL split, which by default will evaluate on Celeba
    kwargs = {
        "val_split": "mafl"
    } if eval_data == "CelebAPrunedAligned_MAFLVal" else {}
    val_dataset = constructor(
        train=False,
        pair_warper=warper,
        use_keypoints=True,
        imwidth=imwidth,
        crop=crop,
        root=root,
        **kwargs,
    )
    # NOTE: Since the matching is performed with pairs, we fix the ordering and then
    # use all pairs for datasets with even numbers of images, and all but one for
    # datasets that have odd numbers of images (via drop_last=True)
    data_loader = DataLoader(val_dataset,
                             batch_size=2,
                             collate_fn=dict_coll,
                             shuffle=False,
                             drop_last=True)

    # build model architecture
    model = get_instance(module_arch, 'arch', config)
    model.summary()

    # load state dict
    ckpt_path = config._args.resume
    logger.info(f"Loading checkpoint: {ckpt_path} ...")
    checkpoint = torch.load(ckpt_path)
    # checkpoint = torch.load(config["weights"])
    state_dict = checkpoint['state_dict']
    if config['n_gpu'] > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(clean_state_dict(state_dict))
    if config['n_gpu'] > 1:
        model = model.module

    model = model.to(device)
    model.train()

    if dense_match:
        warp_dir = Path(config["warp_dir"]) / config["name"]
        warp_dir = warp_dir / "disable_warps{}".format(disable_warps)
        if not warp_dir.exists():
            warp_dir.mkdir(exist_ok=True, parents=True)
        writer = SummaryWriter(warp_dir)

    model.eval()
    same_errs = []
    diff_errs = []

    torch.manual_seed(0)
    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            data, meta = batch["data"], batch["meta"]

            if (config.get("mini_eval", False) and i > 3):
                break
            # if i == 0:
            #     # Checksum to make sure warps are deterministic
            #     if True:
            #         # redo later
            #         if data.shape[2] == 64:
            #             assert float(data.sum()) == -553.9221801757812
            #         elif data.shape[2] == 128:
            #             assert float(data.sum()) == 754.1907348632812

            data = data.to(device)
            output = model(data)

            descs = output[0]
            descs1 = descs[0::2]  # 1st in pair (more warped)
            descs2 = descs[1::2]  # 2nd in pair
            ims1 = data[0::2].cpu()
            ims2 = data[1::2].cpu()

            im_source = ims1[0]
            im_same = ims2[0]
            im_diff = ims2[1]

            C, imH, imW = im_source.shape
            B, C, H, W = descs1.shape
            stride = imW / W

            desc_source = descs1[0]
            desc_same = descs2[0]
            desc_diff = descs2[1]

            if not dense_match:
                kp1 = meta['kp1']
                kp2 = meta['kp2']
                kp_source = kp1[0]
                kp_same = kp2[0]
                kp_diff = kp2[1]

            if config.get("vis", False):
                fig = plt.figure()  # a new figure window
                ax1 = fig.add_subplot(1, 3, 1)
                ax2 = fig.add_subplot(1, 3, 2)
                ax3 = fig.add_subplot(1, 3, 3)

                ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                ax2.imshow(norm_range(im_same).permute(1, 2, 0))
                ax3.imshow(norm_range(im_diff).permute(1, 2, 0))

                if not dense_match:
                    ax1.scatter(kp_source[:, 0], kp_source[:, 1], c='g')
                    ax2.scatter(kp_same[:, 0], kp_same[:, 1], c='g')
                    ax3.scatter(kp_diff[:, 0], kp_diff[:, 1], c='g')

            if False:
                fsrc = F.normalize(desc_source, p=2, dim=0)
                fsame = F.normalize(desc_same, p=2, dim=0)
                fdiff = F.normalize(desc_diff, p=2, dim=0)
            else:
                fsrc = desc_source.clone()
                fsame = desc_same.clone()
                fdiff = desc_diff.clone()

            if dense_match:
                # if False:
                #     print("DEBUGGING WITH IDENTICAL FEATS")
                #     fdiff = fsrc
                # tic = time.time()
                grid = dense_desc_match(fsrc, fdiff)
                im_warped = F.grid_sample(im_source.view(1, 3, imH, imW), grid)
                im_warped = im_warped.squeeze(0)
                # print("done matching in {:.3f}s".format(time.time() - tic))
                plt.close("all")
                if config["subplots"]:
                    fig = plt.figure()  # a new figure window
                    ax1 = fig.add_subplot(1, 3, 1)
                    ax2 = fig.add_subplot(1, 3, 2)
                    ax3 = fig.add_subplot(1, 3, 3)
                    ax1.imshow(norm_range(im_source).permute(1, 2, 0))
                    ax2.imshow(norm_range(im_diff).permute(1, 2, 0))
                    ax3.imshow(norm_range(im_warped).permute(1, 2, 0))
                    triplet_dest = warp_dir / "triplet-{:05d}.jpg".format(i)
                    fig.savefig(triplet_dest)
                else:
                    triplet_dest_dir = warp_dir / "triplet-{:05d}".format(i)
                    if not triplet_dest_dir.exists():
                        triplet_dest_dir.mkdir(exist_ok=True, parents=True)
                    for jj, im in enumerate((im_source, im_diff, im_warped)):
                        plt.axis("off")
                        fig = plt.figure(figsize=(1.5, 1.5))
                        ax = plt.Axes(fig, [0., 0., 1., 1.])
                        ax.set_axis_off()
                        fig.add_axes(ax)
                        # ax.imshow(data, cmap = plt.get_cmap("bone"))
                        im_ = norm_range(im).permute(1, 2, 0)
                        ax.imshow(im_)
                        dest_path = triplet_dest_dir / "im-{}-{}.jpg".format(
                            i, jj)
                        plt.savefig(str(dest_path), dpi=im_.shape[0])
                        # plt.savefig(filename, dpi = sizes[0])
                writer.add_figure('warp-triplets', fig)
            else:
                for ki, kp in enumerate(kp_source):
                    x, y = np.array(kp)
                    gt_same_x, gt_same_y = np.array(kp_same[ki])
                    gt_diff_x, gt_diff_y = np.array(kp_diff[ki])
                    same_x, same_y = find_descriptor(x, y, fsrc, fsame, stride)

                    err = compute_pixel_err(
                        pred_x=same_x,
                        pred_y=same_y,
                        gt_x=gt_same_x,
                        gt_y=gt_same_y,
                        imwidth=imwidth,
                        crop=crop,
                    )
                    same_errs.append(err)
                    diff_x, diff_y = find_descriptor(x, y, fsrc, fdiff, stride)
                    err = compute_pixel_err(
                        pred_x=diff_x,
                        pred_y=diff_y,
                        gt_x=gt_diff_x,
                        gt_y=gt_diff_y,
                        imwidth=imwidth,
                        crop=crop,
                    )
                    diff_errs.append(err)
                    if config.get("vis", False):
                        ax2.scatter(same_x, same_y, c='b')
                        ax3.scatter(diff_x, diff_y, c='b')

            if config.get("vis", False):
                zs_dispFig()
                fig.savefig('/tmp/matching.pdf')

    print("")  # cleanup print from tqdm subtraction
    logger.info("Matching Metrics:")
    logger.info(f"Mean Pixel Error (same-identity): {np.mean(same_errs)}")
    logger.info(f"Mean Pixel Error (different-identity) {np.mean(diff_errs)}")
Code example #21
File: views.py, Project: dimyG/zakanda_public
def create_bet_tree(user,
                    bet_slip_events_list,
                    bookmaker,
                    bet_amount,
                    bet_tag,
                    bet_description,
                    bet_description_url,
                    date,
                    is_past=False):
    logger.info('initiating total bet tree transaction...')
    logger.debug(
        'user: %s, bookmaker: %s, bet_tag: %s, bet_tag_balance_snapshot: %s, date: %s (%s), '
        'amount: %s, description: %s (%s), url: %s (%s), is_past: %s', user,
        bookmaker, bet_tag, bet_tag.balance, date, type(date), bet_amount,
        bet_description, type(bet_description), bet_description_url,
        type(bet_description_url), is_past)
    try:
        with transaction.atomic():
            current_tag_balance = bet_tag.balance
            total_bet = games.models.TotalBet.objects.create(
                user=user,
                bookmaker=bookmaker,
                bet_tag=bet_tag,
                bet_tag_balance_snapshot=current_tag_balance,
                date=date,
                amount=bet_amount,
                description=bet_description,
                url=bet_description_url,
                is_past=is_past)

            # TODO BET SYSTEM: a function that checks the bet system and does the following for each individual bet of the
            # "total" bet. It returns one "bet_slip_events_list"-like list per bet, and for each list we do the following.
            # As it is now: if there is a split asian_handicap then the bet is split in two; one total_bet is created that has two
            # bets with half the bet_amount each. If there is more than one split asian_handicap then multiple bets are created.
            # If in the future we add bet systems, then if the user selects a bet system for 3 events (1 triple and 3 doubles)
            # and there is a split asian_handicap, we split the bet that has it in two. So in the previous case the triple
            # bet will be split in two with half the amount, and 2 of the doubles will be split in two, so this total bet will have
            # 7 bets instead of 4. First we extract the bets of the bet_system and check those bets for split asian handicaps.

            # split_asian_handicap_in_bet_slip = is_there_split_asian_handicap(bet_slip_events_list)
            split_asian_handicap_in_bet_slip = False
            bets = []
            if split_asian_handicap_in_bet_slip:
                # create_asian_handicap_bets(bet_slip_events_list, total_bet, bet_amount)
                bet_events = []  # just to remove the pycharm warning
                pass
            else:
                bet_events, bet_events_created, bet_events_total_odd = create_bet_events(
                    bet_slip_events_list)
                if not bet_events:
                    # So the transaction will roll back
                    raise utils.BetEventsDontExist
                # We create a new bet only if there is no bet already connected with these bet_events
                # (with these bet_events exactly) If it is connected also with some other bet_events then a new one will be used
                bet = None
                if True not in bet_events_created:
                    # if all bet_events already existed
                    bets_with_same_bet_events = utils.get_exact_m2m_match(
                        games.models.Bet, 'bet_events', bet_events)
                    matched_bet = utils.get_instance(games.models.Bet,
                                                     bets_with_same_bet_events,
                                                     amount=bet_amount,
                                                     odd=bet_events_total_odd)
                    bet = matched_bet
                if not bet:
                    # if bet = None or matched_bet = None
                    bet, bet_odd = create_bet(bet_amount, bet_events)
                total_bet.bets.add(bet)
                bets.append(bet)
            total_bet_odd = total_bet.update_odd()
    except Exception as e:
        logger.error(
            "Exception on total bet creation! %s. transaction rolled back", e)
        total_bet, bets, bet_events = (None, None, None)
        return total_bet, bets, bet_events
    return total_bet, bets, bet_events
Code example #22
File: classifier.py, Project: zzll136/ModelFeast
    def set_optimizer(self, name="Adam", **kwargs):
        self.config["optimizer"] = {"type": name, "args": kwargs}
        optimizer_params = filter(lambda p: p.requires_grad, self.model.parameters())
        # print(len(list(optimizer_params)))
        self.optimizer = get_instance(torch.optim, 'optimizer', self.config, optimizer_params)
Code example #23
File: train.py, Project: shyaoni/ctt
import torch

import utils
from data import data_parser, get_dataset

import module.tokenizer
module.tokenizer.tokenize_pipeline()

if __name__ == '__main__':
    config, args = utils.get_config(data_parser())
    dataset = get_dataset(args.dts)

    vocab = dataset.get_vocab(config.data.vocab_size)
    train_set = dataset.get_set(config.data.train_set, vocab)
    test_set = dataset.get_set(config.data.test_set, vocab)

    predictor = utils.get_instance(config.model, train_set.field)

    predictor.train(train_set, config, dts_valid=test_set)
    predictor.save(config.save_path)
Code example #24
File: test.py, Project: criminalking/COMP790
base_name = config['name']
if args.cuda and not torch.cuda.is_available():
    raise Exception("No GPU found, please run without --cuda")
device = torch.device("cuda" if args.cuda else "cpu")

# load dataset
dataset = get_dataset(config['data_loader']['dataset'], config['data_loader']['use_cameras'], config['data_loader']['input_size'])
testing_data_loader = DataLoader(dataset=dataset, num_workers=config['data_loader']['num_workers'], batch_size=config['data_loader']['test_batch_size'], shuffle=False)

# restore checkpoint
print('===> Restoring model')
module = importlib.import_module('models.{}'.format(config['arch']['type']))
model = module.Model(len(config['data_loader']['use_cameras'])//2).to(device)
model_path = os.path.join(config["trainer"]['checkpoint_dir'], config["trainer"]['restore'])
model.load_state_dict(torch.load(model_path)['state_dict'])
criterion = get_instance(nn, "loss", config)

def validate():
    model.eval()

    targets = np.zeros((len(dataset), 3))
    predictions = np.zeros((len(dataset), 3))

    test_batch_size = config['data_loader']['test_batch_size']
    
    avg_error = 0
    with torch.no_grad():
        for i, batch in enumerate(testing_data_loader):
            input, target = batch['images'].to(device, dtype=torch.float), batch['gaze'].to(device, dtype=torch.float)  # input is 6 images, 3 left + 3 right

            prediction = model(input)
Code example #25
File: train.py, Project: criminalking/COMP790
        batch_size=config['data_loader']['batch_size'],
        shuffle=True)
    testing_data_loader = DataLoader(
        dataset=test_dataset,
        num_workers=config['data_loader']['num_workers'],
        batch_size=config['data_loader']['test_batch_size'])

print('===> Building model')
module = importlib.import_module('models.{}'.format(config['arch']['type']))
model = module.Model(len(config['data_loader']['use_cameras']) // 2).to(device)
if config["trainer"]['restore']:
    print('===> Restoring model')
    model_path = os.path.join(config["trainer"]['checkpoint_dir'],
                              config["trainer"]['restore'])
    model.load_state_dict(torch.load(model_path)['state_dict'])
criterion = get_instance(nn, "loss", config)
optimizer = get_instance(optim, "optimizer", config, model.parameters())
scheduler = get_instance(optim.lr_scheduler, "lr_scheduler", config, optimizer)


def train(epoch, last_loss):
    epoch_loss = 0
    model.train()

    end_total = 0
    for iteration, batch in enumerate(training_data_loader, 1):
        start = time.time()
        end_data = start - end_total

        input, target = batch['images'].to(
            device, dtype=torch.float), batch['gaze'].to(
Code example #26
File: __init__.py, Project: shyaoni/ctt
def get_dataset(name, *args, **kwargs):
    return get_instance('data.' + name, *args, **kwargs)
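In the shyaoni/ctt project, get_instance apparently takes a dotted import path rather than a module object plus config key (compare utils.get_instance(config.model) in the other ctt snippets). The following is only a guess at one possible implementation of such a string-based helper, assuming the last dotted component names the class or factory to build.

import importlib

def get_instance(path, *args, **kwargs):
    # split 'package.module.Name' into the module to import and the attribute to instantiate
    module_path, _, attr_name = path.rpartition('.')
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)(*args, **kwargs)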
Code example #27
import yaml
import datasets
from utils import get_instance

with open('../configs/newcode.yaml') as stream:
    config = yaml.safe_load(stream)

cel = get_instance(datasets, config, 'data', 'dataset', train=False)
print(cel)
print(len(cel))
Code example #28
def main(config, resume):
    logger = config.get_logger('train')
    seeds = [int(x) for x in config._args.seeds.split(",")]
    torch.backends.cudnn.benchmark = True

    # print information in logger
    logger.info("Launching experiment with config:")
    logger.info(config)

    # what are the seeds?
    if len(seeds) > 1:
        run_metrics = []

    for seed in seeds:
        tic = time.time()

        # use manual seed
        if True:
            logger.info(f"Setting experiment random seed to {seed}")
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        ## instantiate the model and print its info
        model = get_instance(module_arch, 'arch', config)
        logger.info(model)

        loader_kwargs = {}
        coll_func = config.get("collate_fn", "dict_flatten")
        if coll_func == "flatten":
            loader_kwargs["collate_fn"] = coll
        elif coll_func == "dict_flatten":
            loader_kwargs["collate_fn"] = dict_coll
        else:
            raise ValueError(
                "collate function type {} unrecognised".format(coll_func))

        dataset = get_instance(module_pc_data,
                               'dataset',
                               config,
                               split='train',
                               has_warper=True)
        if config["disable_workers"]:
            num_workers = 0
        else:
            num_workers = 4
        data_loader = DataLoader(
            dataset,
            batch_size=int(config["batch_size"]),
            num_workers=num_workers,
            shuffle=True,
            drop_last=True,
            pin_memory=True,
            **loader_kwargs,
        )
        val_dataset = get_instance(module_pc_data,
                                   'dataset',
                                   config,
                                   split='test')
        valid_data_loader = DataLoader(val_dataset,
                                       batch_size=1,
                                       **loader_kwargs)

        # get function handles of loss and metrics
        loss = getattr(module_loss, config['loss'])
        loss_args = config['loss_args']
        loss = loss(**loss_args)

        metrics = [getattr(module_metric, met) for met in config['metrics']]

        ## model parameter statistics
        trainable_params = list(
            filter(lambda p: p.requires_grad, model.parameters()))
        biases = [
            x.bias for x in model.modules() if isinstance(x, nn.Conv2d)
            or isinstance(x, nn.Conv1d) or isinstance(x, nn.Linear)
        ]
        weights = [
            x.weight for x in model.modules() if isinstance(x, nn.Conv2d)
            or isinstance(x, nn.Conv1d) or isinstance(x, nn.Linear)
        ]
        trainbiases = [
            x for x in trainable_params if sum([(x is b) for b in biases])
        ]
        trainweights = [
            x for x in trainable_params if sum([(x is w) for w in weights])
        ]
        trainparams = [
            x for x in trainable_params
            if not sum([(x is b)
                        for b in biases]) and not sum([(x is w)
                                                       for w in weights])
        ]
        print(len(trainparams), 'Parameters', len(trainbiases), 'Biases',
              len(trainweights), 'Weights', len(trainable_params),
              'Total Params')

        ## set different lr to weight and bias
        bias_lr = config.get('bias_lr', None)
        other_lr = config.get('other_lr', None)
        if bias_lr is not None and other_lr is not None:
            print("bias_lr is not None and other_lr is not None")
            optimizer = get_instance(
                torch.optim,
                'optimizer',
                config,
                [
                    # using the default learning rate
                    {
                        "params": trainweights
                    },
                    # using different learning rates
                    {
                        "params": trainbiases,
                        "lr": bias_lr
                    },
                    {
                        "params": trainparams,
                        "lr": other_lr
                    }
                ])
        elif bias_lr is not None:
            optimizer = get_instance(
                torch.optim,
                'optimizer',
                config,
                [
                    # using the default learning rate
                    {
                        "params": trainweights + trainparams
                    },
                    # using different learning rates
                    {
                        "params": trainbiases,
                        "lr": bias_lr
                    },
                ])
        else:
            optimizer = get_instance(torch.optim, 'optimizer', config,
                                     trainable_params)

        ## scheduler
        lr_scheduler = get_instance(torch.optim.lr_scheduler, 'lr_scheduler',
                                    config, optimizer)

        ## log dir
        print("config.log_dir: ", config.log_dir)
        print("config.result_dir: ", config.result_dir)

        ## pointcloud trainer
        trainer = PCTrainer(
            model=model,
            loss=loss,
            metrics=metrics,  # metrics
            resume=resume,  # resume path
            config=config,  # config structure
            optimizer=optimizer,
            data_loader=data_loader,
            lr_scheduler=lr_scheduler,
            mini_train=config._args.mini_train,
            valid_data_loader=
            valid_data_loader,  # no need for validation this time
        )
        trainer.train()
        duration = time.strftime('%Hh%Mm%Ss', time.gmtime(time.time() - tic))
        logger.info(f"Training took {duration}")
Code example #29
File: base.py, Project: hcs/hcs-cloud
    def check_hostname(ec2, hostname):
        ''' Returns False if hostname is already in use, True otherwise '''

        return utils.get_instance(ec2, lambda i: "Name" in i.tags and i.tags['Name'] == hostname) is None
Code example #30
File: base.py, Project: hcs/hcs-cloud
    def check_ip(ec2, ip):
        ''' Returns False if IP is already in use, True otherwise '''

        return utils.get_instance(ec2, lambda i: i.private_ip_address == ip) is None
Code example #31
File: simulate.py, Project: shyaoni/ctt
import torch

import utils
from data import data_parser, get_dataset

import module.tokenizer
module.tokenizer.tokenize_pipeline()

import module.chat as chat

if __name__ == '__main__':
    config, args = utils.get_config(data_parser())
    dataset = get_dataset(args.dts)

    vocab = dataset.get_vocab(config.vocab_size)
    predictor = utils.get_instance(config.predictor_type, None, config.model)
    predictor.load(config.save_path)

    train_set = dataset.utterance(vocab, train=True)

    chatbot = chat.Chatbot(
        chat.InferGraphAgent(predictor).corpus(train_set,
                                               config).build(width=100))

    from IPython import embed
    embed()
Code example #32
File: classifier.py, Project: zzll136/ModelFeast
    def set_lr_scheduler(self, name="StepLR", **kwargs):
        self.config["lr_scheduler"] = {"type": name, "args": kwargs}
        self.lr_scheduler = get_instance(torch.optim.lr_scheduler, 'lr_scheduler',
                                         self.config, self.optimizer)
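Together with set_optimizer in code example #22, these setters allow the optimizer and scheduler to be swapped after the classifier is built. A hypothetical call sequence (model name and hyperparameter values are only examples, not taken from the project):

# build the classifier, then override the default Adam/StepLR combination
clf = classifier(model='resnet18', n_classes=10, img_size=(224, 224))
clf.set_optimizer("SGD", lr=0.01, momentum=0.9, weight_decay=1e-4)
clf.set_lr_scheduler("CosineAnnealingLR", T_max=50)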