Example #1
File: meiju.py  Project: zzcv/python
def main():
    print("欢迎使用 美剧天堂 爬取脚本")
    print("=" * 20)
    print("魔幻/科幻:1\n灵异/惊悚:2\n都市/感情:3\n犯罪/历史:4\n选秀/综艺:5\n动漫/卡通:6")
    print("=" * 20)
    ftype = input('请输入需要爬取的类型的代号:')
    start_url = "http://www.meijutt.com/file/list%s.html" % ftype
    ourl = openurl.OpenUrl(start_url, 'gb2312')
    code, doc = ourl.openurl()
    mylog = Logger(
        os.path.join(os.path.abspath(os.path.curdir), 'misc/spider_log.yaml'))
    logger = mylog.outputLog()
    if code == 200:
        selecter = etree.HTML(doc)
        pages = selecter.xpath(
            "//div[@class='page']/span/text()")[0].split()[0].split('/')[1]
        firstpage_links = selecter.xpath("//a[@class='B font_14']/@href")
        for firstpage_link in firstpage_links:
            name, download_links = get_downlink(firstpage_link)
            send_mysql(name, download_links, logger)
            time.sleep(0.5)

    for page in range(2, int(pages) + 1):  # pages is the total page count, so include the last page
            page_url = 'http://www.meijutt.com/file/list%s_%s.html' % (ftype,
                                                                       page)
            for link in page_link(page_url):
                name, download_links = get_downlink(link)
                if name != '' and download_links != '':
                    send_mysql(name, download_links, logger)
                    time.sleep(0.5)
    else:
        print("[%s] error..." % start_url)

    print("Done.")
Example #2
    def __init__(self, bam, output, context_file):
        """Parameter initialization"""
        self.bam_file = bam
        self.in_bam = pysam.AlignmentFile(bam, "rb")
        self.filtered_bam = pysam.AlignmentFile(output,
                                                "wb",
                                                template=self.in_bam)
        # a list of ints is used to count pairs assigned to different filters
        # 0: count_input_alignments
        # 1: count_input_pairs
        # 2: count_filtered_pairs
        # 3: count_multimapped
        # 4: count_star_chimeric_alignments
        # 5: count_qcd_alignments
        # 6: count_unmapped
        # 7: count_10bp_s_clip
        # 8: count_proper_pair
        # 9: count_not_filtered_but_in_fusion_gene
        self.counter = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        self.last_time = 0
        self.logger = Logger("{}.fusionReadFilterLog".format(output))

        self.coord_dict = {}
        self.get_ranges(context_file)
        print(self.coord_dict)
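The index-to-meaning mapping documented in the comments above is easy to get wrong when self.counter is updated by bare integer positions. A minimal sketch (not part of the original project; the class name FilterCounter is hypothetical) of how the same counters could be given names with enum.IntEnum:

    from enum import IntEnum

    class FilterCounter(IntEnum):
        """Hypothetical named indices for the per-filter pair counters."""
        INPUT_ALIGNMENTS = 0
        INPUT_PAIRS = 1
        FILTERED_PAIRS = 2
        MULTIMAPPED = 3
        STAR_CHIMERIC_ALIGNMENTS = 4
        QCD_ALIGNMENTS = 5
        UNMAPPED = 6
        TEN_BP_S_CLIP = 7
        PROPER_PAIR = 8
        NOT_FILTERED_BUT_IN_FUSION_GENE = 9

    counter = [0] * len(FilterCounter)
    counter[FilterCounter.PROPER_PAIR] += 1  # reads better than counter[8] += 1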
Example #3
File: piaohua.py  Project: zzcv/python
    def __init__(self, ftype):
        self.__ftype = ftype
        self.__redis_link = self.__redis_connect()
        mylog = Logger(
            os.path.join(os.path.abspath(os.path.curdir),
                         'misc/spider_log.yaml'))
        self.__logger = mylog.outputLog()
Example #4
File: runningman.py  Project: zzcv/python
def main():
    mylog = Logger(
        os.path.join(os.path.abspath(os.path.curdir), 'misc/spider_log.yaml'))
    logger = mylog.outputLog()
    year = input("Enter the year: ")
    allurl = get_links(year)
    downurl(allurl, logger)
Example #5
    def __init__(self, cmd, input_paths, working_dir):
        """Parameter initialization and work folder creation. Start of progress logging."""
        self.working_dir = os.path.abspath(working_dir)
        self.logger = Logger(os.path.join(self.working_dir, "easyfuse_processing.log"))
        IOMethods.create_folder(self.working_dir, self.logger)
        copy(os.path.join(cfg.module_dir, "config.py"), working_dir)
        self.logger.info("Starting easyfuse: CMD - {}".format(cmd))
        self.input_paths = [os.path.abspath(file) for file in input_paths]
        self.samples = SamplesDB(os.path.join(self.working_dir, "samples.db"))
Example #6
    def __init__(self, scratch_path, fetchdata_path, sample_id):
        """Parameter initialization and work folder creation."""
        self.scratch_path = scratch_path
        self.fetchdata_path = fetchdata_path
        self.sample_id = sample_id
        #self.tools = Samples(os.path.join(scratch_path, os.path.pardir, os.path.pardir, "samples.csv")).get_tool_list_from_state(self.sample_id)
        self.samples = SamplesDB(
            os.path.join(scratch_path, os.path.pardir, "samples.db"))
        self.logger = Logger(os.path.join(self.fetchdata_path,
                                          "fetchdata.log"))
Example #7
    def __init__(self):
        self.__redis_link = self.__redis_connect()
        mylog = Logger(
            os.path.join(os.path.abspath(os.path.curdir),
                         'misc/spider_log.yaml'))
        self.__logger = mylog.outputLog()
        self.mysql_connect = mysql_connect.MysqlConnect(
            os.path.join(os.path.abspath(os.path.curdir),
                         'misc/mysql_data.yaml'))
        self.main_url = 'http://www.hanfan.cc/'
Example #8
    def __init__(self, scratch_path, fusion_output_path, sample_id,
                 tool_num_cutoff, fusiontool_list, sample_log):
        """Parameter initialization and work folder creation."""
        self.scratch_path = scratch_path
        self.fusion_output_path = fusion_output_path
        self.sample_id = sample_id
        self.tool_num_cutoff = int(tool_num_cutoff)
        # urla: if we want to be more generic and allow different annotations, identification of the chr names
        #       (eg "chr1" vs "1" and "chrM" vs "MT") should be performed in advance (a sketch follows this example)
        self.chr_list = ("1", "2", "3", "4", "5", "6", "7", "8", "9", "10",
                         "11", "12", "13", "14", "15", "16", "17", "18", "19",
                         "20", "21", "22", "X", "Y", "MT")
        self.tools = fusiontool_list.split(",")
        self.logger = Logger(sample_log)
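A minimal sketch (not part of the original project; normalize_chrom is a hypothetical helper) of the kind of chromosome-name normalization the "urla" comment above refers to, mapping spellings such as "chr1" or "chrM" onto the plain names used in chr_list:

    def normalize_chrom(name):
        """Map common chromosome-name spellings onto plain names like "1", "X", "MT"."""
        if name.lower().startswith("chr"):
            name = name[3:]
        return "MT" if name.upper() in ("M", "MT") else name

    # normalize_chrom("chr1") -> "1", normalize_chrom("chrM") -> "MT", normalize_chrom("X") -> "X"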
Example #9
File: main.py  Project: DorisxinDU/mine-1
parser = argparse.ArgumentParser(description='GAN without MI')
parser.add_argument('--config', type=str, default='./configs/spiral_mine.yml',
                    help='Path to config file')
opts = parser.parse_args()
params = get_config(opts.config)
print(params)

train_loader, val_loader = spiral_dataloader(params)

if params['use_mine']:
    model = GAN_MI(params)
else:
    model = GAN(params)

if params['use_cuda']:
    model = model.cuda()

logger = Logger(params['logs'])

exp_logs = params['logs'] + params['exp_name'] + '_' + timestamp + '/' 
exp_results = params['results'] + params['exp_name'] + '_' + timestamp + '/'
mkdir_p(exp_logs)
mkdir_p(exp_results)

if params['use_mine']:
    gan_trainer = GANTrainerMI(model, params, train_loader, val_loader, logger, exp_results, exp_logs)
else:
    gan_trainer = GANTrainerVanilla(model, params, train_loader, val_loader, logger, exp_results, exp_logs)

gan_trainer.train()
Example #10
    model = DeepCross(opt=opt)
    model = model.cuda()

    if opt.loader:
        print("load checkpoint file .")
        model.load_state_dict(
            torch.load(os.path.join('models', 'model-1.ckpt')))

    current_lr = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=current_lr)

    # criterion = nn.BCEWithLogitsLoss()
    criterion = FocalLoss()
    # criterion = nn.BCELoss()
    logger = Logger('./logs/')

    for epoch in range(2, opt.num_epoches):
        # schedule learning rate
        frac = epoch // 2
        decay_factor = 0.9**frac
        current_lr = 1e-3 * decay_factor  # decay from the base learning rate, not the already-decayed value
        utils.set_lr(optimizer, current_lr)

        # training
        model.train()
        start = time.time()

        for i, data in enumerate(train_loader):
            # prepare the data and the corresponding label (the 'click' field)
            user_id = data['user_id'].cuda()
Example #11
def main():
    mylog = Logger(
        os.path.join(os.path.abspath(os.path.curdir), 'misc/spider_log.yaml'))
    logger = mylog.outputLog()
    items = spiderman()
    for item in items:
        send_mysql(item, logger)
Example #12
    def __init__(self, opt):
        lopt = opt.logger
        topt = opt.trainer
        mopt = opt.model
        gopt = opt.model.gen
        copt = opt.model.crit
        goopt = opt.optim.gen
        coopt = opt.optim.crit

        #CUDA configuration
        if opt.device == 'cuda' and torch.cuda.is_available():
            os.environ['CUDA_VISIBLE_DEVICES'] = opt.deviceId
            torch.backends.cudnn.benchmark = True
        else:
            opt.device = 'cpu'

        self.device = torch.device(opt.device)

        #logger
        self.logger_ = Logger(self, gopt.latentSize, topt.resumeTraining,
                              opt.tick, opt.loops, lopt.logPath, lopt.logStep,
                              lopt.saveImageEvery, lopt.saveModelEvery,
                              lopt.logLevel, self.device)
        self.logger = self.logger_.logger

        #Logging configuration parameters
        if opt.device == 'cuda':
            num_gpus = len(opt.deviceId.split(','))
            self.logger.info("Using {} GPUs.".format(num_gpus))
            self.logger.info("Training on {}.\n".format(
                torch.cuda.get_device_name(0)))

        #data loader
        dlopt = opt.dataLoader

        self.dataLoader = DataLoader(dlopt.dataPath, dlopt.resolution,
                                     dlopt.noChannels, dlopt.batchSize,
                                     dlopt.numWorkers)

        self.resolution, self.nCh = self.dataLoader.resolution, self.dataLoader.nCh

        # training opt
        assert opt.tick > 0, self.logger.error(
            f'The number of ticks should be a positive integer, got {opt.tick} instead'
        )
        self.tick = float(opt.tick)

        assert opt.loops > 0, self.logger.error(
            f'The number of loops should be a positive integer, got {opt.loops} instead'
        )
        self.loops = int(opt.loops)

        self.imShown = 0
        self.batchShown = self.imShown // self.dataLoader.batchSize

        assert topt.lossFunc in ['NSL', 'WD'], self.logger.error(
            'The specified loss model is not supported. Please choose either "NSL" or "WD"'
        )
        self.lossFunc = topt.lossFunc
        self.criterion = NonSaturatingLoss if self.lossFunc == 'NSL' else WassersteinLoss

        self.applyLossScaling = bool(topt.applyLossScaling)

        self.paterm = topt.paterm
        self.lambg = float(topt.lambg)
        self.gLazyReg = max(topt.gLazyReg, 1)
        self.styleMixingProb = float(topt.styleMixingProb)

        self.meanPathLength = 0.

        self.plDecay = topt.meanPathLengthDecay

        self.pathRegWeight = topt.pathLengthRWeight

        assert topt.nCritPerGen > 0, self.logger.error(
            f'Trainer ERROR: The number of critic training loops per generator loop should be an integer >= 1 (got {topt.nCritPerGen})'
        )
        self.nCritPerGen = int(topt.nCritPerGen)

        self.lambR2 = float(topt.lambR2) if topt.lambR2 else 0  #lambda R2
        self.obj = float(topt.obj) if topt.obj else 1  #objective value (1-GP)

        self.lambR1 = float(topt.lambR1) if topt.lambR1 else 0  #lambda R1

        self.epsilon = float(
            topt.epsilon) if topt.epsilon else 0  #epsilon (drift loss)

        self.cLazyReg = max(topt.cLazyReg, 1)

        self.kUnroll = int(topt.unrollCritic) if topt.unrollCritic else 0

        assert self.kUnroll >= 0, self.logger.error(
            f'Trainer ERROR: The unroll parameter is less than zero ({self.kUnroll})'
        )

        #Common model parameters
        common = {
            'fmapMax': mopt.fmapMax,
            'fmapMin': mopt.fmapMin,
            'fmapDecay': mopt.fmapDecay,
            'fmapBase': mopt.fmapBase,
            'activation': mopt.activation,
            'upsample': mopt.sampleMode,
            'downsample': mopt.sampleMode
        }

        #Generator model parameters
        self.gen = Generator(**common, **gopt).to(self.device)
        self.latentSize = self.gen.mapping.latentSize

        self.logger.info(
            f'Generator constructed. Number of parameters {sum([np.prod([*p.size()]) for p in self.gen.parameters()])}'
        )

        #Critic model parameters
        self.crit = Critic(**mopt, **copt).to(self.device)

        self.logger.info(
            f'Critic constructed. Number of parameters {sum([np.prod([*p.size()]) for p in self.crit.parameters()])}'
        )

        #Generator optimizer parameters
        glr, beta1, beta2, epsilon, lrDecay, lrDecayEvery, lrWDecay = list(
            goopt.values())

        assert lrDecay >= 0 and lrDecay <= 1, self.logger.error(
            'Trainer ERROR: The decay constant for the learning rate of the generator must be a constant between [0, 1]'
        )
        assert lrWDecay >= 0 and lrWDecay <= 1, self.logger.error(
            'Trainer ERROR: The weight decay constant for the generator must be a constant between [0, 1]'
        )
        self.gOptimizer = Adam(filter(lambda p: p.requires_grad,
                                      self.gen.parameters()),
                               lr=glr,
                               betas=(beta1, beta2),
                               weight_decay=lrWDecay,
                               eps=epsilon)

        if lrDecayEvery and lrDecay:
            self.glrScheduler = lr_scheduler.StepLR(self.gOptimizer,
                                                    step_size=lrDecayEvery *
                                                    self.tick,
                                                    gamma=lrDecay)
        else:
            self.glrScheduler = None

        self.logger.info(f'Generator optimizer constructed')

        #Critic optimizer parameters
        clr, beta1, beta2, epsilon, lrDecay, lrDecayEvery, lrWDecay = list(
            coopt.values())

        assert lrDecay >= 0 and lrDecay <= 1, self.logger.error(
            'Trainer ERROR: The decay constant for the learning rate of the critic must be a constant between [0, 1]'
        )
        assert lrWDecay >= 0 and lrWDecay <= 1, self.logger.error(
            'Trainer ERROR: The weight decay constant for the critic must be a constant between [0, 1]'
        )

        self.cOptimizer = Adam(filter(lambda p: p.requires_grad,
                                      self.crit.parameters()),
                               lr=clr,
                               betas=(beta1, beta2),
                               weight_decay=lrWDecay,
                               eps=epsilon)

        if lrDecayEvery and lrDecay:
            self.clrScheduler = lr_scheduler.StepLR(self.cOptimizer,
                                                    step_size=lrDecayEvery *
                                                    self.tick,
                                                    gamma=lrDecay)
        else:
            self.clrScheduler = None

        self.logger.info(f'Critic optimizer constructed')

        self.preWtsFile = opt.preWtsFile
        self.resumeTraining = bool(topt.resumeTraining)
        self.loadPretrainedWts(resumeTraining=self.resumeTraining)

        self.logger.info(
            f'The trainer has been instantiated.... Starting step: {self.imShown}. Resolution: {self.resolution}'
        )

        self.logArchitecture(clr, glr)
Example #13
def eval_patch_shuffle(model,
                       dataset_builder,
                       max_num_devide: int,
                       num_samples: int,
                       batch_size: int,
                       num_workers: int,
                       top_k: int,
                       log_dir: str,
                       log_params: dict = {},
                       suffix: str = '',
                       shuffle: bool = False,
                       **kwargs):
    """
    Args
    - model: NN model
    - dataset_builder: DatasetBuilder class object
    - max_num_devide: max number of division
    - num_samples: number of sample to use. if -1, all samples are used
    - batch_size: size of batch
    - num_workers: number of workers
    - top_k: use top_k accuracy
    - log_dir: log directory
    - log_params: params which is logged in dataframe. these params are useful for legend.
    - suffix: suffix of log
    - shuffle: shuffle data
    """
    assert max_num_devide >= 1
    assert num_samples >= 1 or num_samples == -1
    assert batch_size >= 1
    assert num_workers >= 1
    assert top_k >= 1

    log_path = os.path.join(log_dir, 'patch_shuffle_result' + suffix + '.csv')
    logger = Logger(path=log_path, mode='test')

    # log params
    # logger.log(log_params)

    acc_dict = {}
    images_list = []

    for num_devide in tqdm.tqdm(range(1, max_num_devide + 1)):
        log_dict = collections.OrderedDict()

        # build Patch Shuffled dataset
        patch_shuffle_transform = PatchShuffle(num_devide, num_devide)
        dataset = dataset_builder(train=False,
                                  normalize=True,
                                  optional_transform=[patch_shuffle_transform])
        if num_samples != -1:
            num_samples = min(num_samples, len(dataset))
            indices = [i for i in range(num_samples)]
            dataset = torch.utils.data.Subset(dataset, indices)
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=shuffle,
                                             num_workers=num_workers,
                                             pin_memory=True)

        with torch.autograd.no_grad():
            num_correct = 0.0
            for i, (x, t) in enumerate(loader):
                model.eval()
                x = x.to('cuda', non_blocking=True)
                t = t.to('cuda', non_blocking=True)

                model.zero_grad()
                logit = model(x)
                num_correct += get_num_correct(logit, t, topk=top_k)

                if i == 0:
                    images_list.append(x[10])

        acc = num_correct / float(len(dataset))
        key = '{num_devide}'.format(num_devide=num_devide)
        acc_dict[key] = acc

        log_dict['num_devide'] = num_devide
        log_dict['accuracy'] = acc
        logger.log(log_dict)
        print(acc_dict)

    # save data
    torch.save(
        acc_dict,
        os.path.join(log_dir, 'patch_shuffle_acc_dict' + suffix + '.pth'))
    torchvision.utils.save_image(torch.stack(images_list, dim=0),
                                 os.path.join(
                                     log_dir,
                                     'example_images' + suffix + '.png'),
                                 nrow=max_num_devide)
    plot(csv_path=log_path,
         x='num_devide',
         y='accuracy',
         hue=None,
         log_path=os.path.join(log_dir, 'plot.png'),
         save=True)
Example #14
def train(model, vocab, cfg):
    seqtree_coco = SeqtreeCOCO()
    loader = DataLoader(seqtree_coco,
                        batch_size=16,
                        shuffle=True,
                        num_workers=4)
    logdir = os.path.join(cfg.checkpoint_path, cfg.id)
    if not os.path.isdir(logdir):
        os.mkdir(logdir)
    logger = Logger(logdir)

    with open(os.path.join(logdir, 'config.txt'), 'w') as f:
        f.write(str(cfg))
    with open('data/idx2caps.json', 'r') as f:
        cocoid2caps = json.load(f)
    cocoid2caps = {int(k): v for k, v in cocoid2caps.items()}
    init_scorer('coco-train-idxs')

    infos = {}
    # if cfg.start_from is not None:
    #     with open(os.path.join(cfg.start_from, 'infos_' + cfg.start_from + '_best.pkl'), 'rb') as f:
    #         infos = pickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})

    best_val_score = 0

    update_lr_flag = True
    if cfg.caption_model == 'att_model' or cfg.caption_model == 'tree_model' \
            or cfg.caption_model == 'tree_model_1' or cfg.caption_model == 'tree_model_md' \
            or cfg.caption_model == 'tree_model_2' or cfg.caption_model == 'tree_model_md_att' \
            or cfg.caption_model == 'tree_model_md_sob' or cfg.caption_model == 'tree_model_md_in' \
            or cfg.caption_model == 'drnn':
        # crit = nn.CrossEntropyLoss()
        crit = LanguageModelCriterion()
        rl_crit = RewardCriterion()
    else:
        raise Exception("Caption model not supported: {}".format(
            cfg.caption_model))

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.learning_rate)

    num_period_best = 0
    current_score = 0
    start = time.time()

    print("start training...")

    while True:
        if update_lr_flag:
            if epoch > cfg.learning_rate_decay_start >= 0:
                frac = (epoch - cfg.learning_rate_decay_start
                        ) // cfg.learning_rate_decay_every
                decay_factor = cfg.learning_rate_decay_rate**frac
                cfg.current_lr = cfg.learning_rate * decay_factor
                utils.set_lr(optimizer, cfg.current_lr)
            else:
                cfg.current_lr = cfg.learning_rate

        optimizer.zero_grad()
        for data in loader:
            if cfg.use_cuda:
                torch.cuda.synchronize()

            if cfg.caption_model == 'tree_model_md_att':
                temp = [
                    data['word_idx'], data['father_idx'], data['masks'],
                    data['fc_feats'], data['att_feats']
                ]
                temp = [_.cuda() for _ in temp]
                word_idx, father_idx, masks, fc_feats, att_feats = temp

            elif cfg.caption_model == 'tree_model_md' or cfg.caption_model == 'tree_model_md_sob' \
                    or cfg.caption_model == 'tree_model_md_in' or cfg.caption_model == 'drnn':
                temp = [
                    data['word_idx'], data['father_idx'], data['masks'],
                    data['fc_feats']
                ]
                temp = [_.cuda() for _ in temp]
                word_idx, father_idx, masks, fc_feats = temp
                # words = [[vocab.idx2word[word_idx[batch_index][i].item()] for i in range(40)]
                #          for batch_index in range(2)]

            else:
                raise Exception("Caption model not supported: {}".format(
                    cfg.caption_model))

            optimizer.zero_grad()
            # if cfg.caption_model == 'tree_model_md_att':
            #     logprobs = model(word_idx, father_idx, fc_feats, att_feats)
            #     loss = crit(logprobs, word_idx, masks)
            if cfg.caption_model == 'tree_model_md' or cfg.caption_model == 'tree_model_md_sob' \
                    or cfg.caption_model == 'tree_model_md_in' or cfg.caption_model == 'drnn' \
                    or cfg.caption_model == 'tree_model_md_att':
                word_idx, father_idx, mask, seqLogprobs = model._sample(
                    fc_feats, att_feats, max_seq_length=40)
                gen_result = utils.decode_sequence(vocab, word_idx, father_idx,
                                                   mask)
                ratio = utils.seq2ratio(word_idx, father_idx, mask)
                reward = get_self_critical_reward(model, fc_feats, att_feats,
                                                  data, gen_result,
                                                  vocab, cocoid2caps,
                                                  word_idx.size(1), cfg)
                loss = rl_crit(seqLogprobs, mask,
                               torch.from_numpy(reward).float().cuda(), ratio)

            else:
                raise Exception("Caption model not supported: {}".format(
                    cfg.caption_model))

            loss.backward()
            utils.clip_gradient(optimizer, cfg.grad_clip)
            optimizer.step()
            train_loss = loss.item()

            if cfg.use_cuda:
                torch.cuda.synchronize()

            if iteration % cfg.losses_log_every == 0:
                end = time.time()
                logger.scalar_summary('train_loss', train_loss, iteration)
                logger.scalar_summary('learning_rate', cfg.current_lr,
                                      iteration)
                loss_history[iteration] = train_loss
                lr_history[iteration] = cfg.current_lr
                print(
                    "iter {} (epoch {}), learning_rate: {:.6f}, train_loss: {:.6f}, current_cider: {:.3f}, best_cider: {:.3f}, time/log = {:.3f}" \
                        .format(iteration, epoch, cfg.current_lr, train_loss, current_score, best_val_score,
                                end - start))
                start = time.time()

            if (iteration + 1) % cfg.save_checkpoint_every == 0:
                eval_kwargs = {'eval_split': 'val', 'eval_time': False}
                eval_kwargs.update(vars(cfg))
                # lang_stats = eval_utils.eval_split(model, vocab, eval_kwargs)
                lang_stats = eval_seqtree.eval_split(model, vocab, eval_kwargs)
                if cfg.use_cuda:
                    model = model.cuda()

                for k, v in lang_stats.items():
                    logger.scalar_summary(k, v, iteration)

                val_result_history[iteration] = {'lang_stats': lang_stats}

                current_score = lang_stats['CIDEr']
                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                    num_period_best = 1
                else:
                    num_period_best += 1

                if best_flag:
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['val_result_history'] = val_result_history
                    infos['loss_history'] = loss_history
                    infos['lr_history'] = lr_history

                    checkpoint_path = os.path.join(
                        logdir, 'model_' + cfg.id + '_best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    optimizer_path = os.path.join(
                        logdir, 'optimizer_' + cfg.id + '_best.pth')
                    torch.save(optimizer.state_dict(), optimizer_path)
                    print("model saved to {}".format(logdir))
                    with open(
                            os.path.join(logdir,
                                         'infos_' + cfg.id + '_best.pkl'),
                            'wb') as f:
                        pickle.dump(infos, f)

                if num_period_best >= cfg.num_eval_no_improve:
                    print('no improvement, exit({})'.format(best_val_score))
                    sys.exit()

            iteration += 1

        epoch += 1
        if epoch >= cfg.max_epoches != -1:
            break
Example #15
        '%d_%m_%Y_%H_%M_%S')

    parser = argparse.ArgumentParser(description='mine')
    parser.add_argument('--config',
                        type=str,
                        default='./configs/mine.yml',
                        help='Path to config file')
    opts = parser.parse_args()
    params = get_config(opts.config)
    print(params)

    model = Mine(params)
    if params['use_cuda']:
        model = model.cuda()

    if params['training'] and not params['visualize']:
        exp_logs = params['logs'] + params['exp_name'] + '_' + timestamp + '/'
        exp_results = params['results'] + params[
            'exp_name'] + '_' + timestamp + '/'
        mkdir_p(exp_logs)
        mkdir_p(exp_results)

        config_logfile = exp_logs + 'config.json'
        with open(config_logfile, 'w+') as cf:
            json.dump(params, cf)

        optimizer = optim.Adam(model.parameters(), lr=params['lr'])
        logger = Logger(exp_logs)

        train(params)
Example #16
parser.add_argument(
    '--config',
    type=str,
    default=
    '/home/rudra/Downloads/rudra/relationship_modeling/o2p2/physics_engine/configs/pre-planning.yml',
    help='Path to config file')
opts = parser.parse_args()
params = get_config(opts.config)
pp = pprint.PrettyPrinter(indent=2)
pp.pprint(params)

# Define models and dataloaders
train_loader, val_loader = initial_final_dataloader(params)
model = O2P2Model(params)

if params['use_cuda']:
    model = model.cuda()

exp_results_path = params['project_root'] + '/results/' + params[
    'exp_name'] + '_' + timestamp + '/'
exp_logs_path = params['project_root'] + '/logs/' + params[
    'exp_name'] + '_' + timestamp + '/'
mkdir_p(exp_logs_path)
mkdir_p(exp_results_path)

logger = Logger(exp_logs_path)

trainer = O2P2Trainer(params, model, train_loader, val_loader, logger,
                      exp_results_path, exp_logs_path)

trainer.train()