Exemplo n.º 1
0
def main(args):
    """Build the sampler, policy, baseline and meta-learner, then train.

    Args:
        args: Run configuration. This function reads ``log_name``, ``seed``,
            ``make_deterministic``, ``hidden_size``, ``num_layers``,
            ``num_context_params`` and ``device``; the same object is also
            handed to ``set_log``, ``BatchSampler``, ``MetaLearner`` and
            ``train``.
    """
    # Set logging. ``exist_ok`` avoids the check-then-create race when
    # several runs start at the same time.
    os.makedirs("./log", exist_ok=True)

    log = set_log(args)
    tb_writer = SummaryWriter('./log/tb_{0}'.format(args.log_name))

    # Set seed
    set_seed(args.seed, cudnn=args.make_deterministic)

    # Set sampler
    sampler = BatchSampler(args, log)

    # Flattened sizes of the observation / action spaces; the observation
    # size is shared by the policy input and the baseline.
    observation_size = int(np.prod(sampler.observation_space.shape))
    action_size = int(np.prod(sampler.action_space.shape))

    # Set policy
    policy = CaviaMLPPolicy(
        input_size=observation_size,
        output_size=action_size,
        hidden_sizes=(args.hidden_size, ) * args.num_layers,
        num_context_params=args.num_context_params,
        device=args.device)

    # Initialise baseline
    baseline = LinearFeatureBaseline(observation_size)

    # Initialise meta-learner
    metalearner = MetaLearner(sampler, policy, baseline, args, tb_writer)

    # Begin train
    train(sampler, metalearner, args, log, tb_writer)
Exemplo n.º 2
0
def main(opt):
    """Prepare the option dictionary, build the model and launch training.

    The dictionary is modified in place: a random seed is drawn when none is
    given, ``checkpoint_path`` is resolved, every ``*_name`` path option is
    replaced by its resolved counterpart (key without the ``_name`` suffix),
    and ``vocab_size`` is read from the corpus info file. The resulting
    options are written to ``opt_info.json`` inside the checkpoint folder.

    Args:
        opt (dict): Training options. Must contain ``no_cuda`` and the
            ``*_name`` keys listed below; ``seed`` and ``pretrained_path``
            are optional.
    """
    if opt.get('seed', -1) == -1:
        opt['seed'] = random.randint(1, 65534)
    set_seed(opt['seed'])

    # log files and the best model will be saved at 'checkpoint_path'
    opt["checkpoint_path"] = where_to_save_model(opt)
    # exist_ok avoids the race between checking for and creating the folder
    os.makedirs(opt["checkpoint_path"], exist_ok=True)

    # get full paths to load features / corpora
    path_keys = [
        'feats_a_name', 'feats_m_name', 'feats_i_name', 'feats_o_name',
        'feats_t_name', 'reference_name', 'info_corpus_name',
    ]
    for key in path_keys:
        # key[:-5] strips the trailing '_name'
        opt[key[:-5]] = get_dir(opt, key, 'feats' if 'feats' in key else '')
        opt.pop(key)

    # the assignment of 'vocab_size' should be done before defining the model
    # NOTE: pickle must only be used on trusted corpus files.
    with open(opt['info_corpus'], 'rb') as f:
        opt['vocab_size'] = len(pickle.load(f)['info']['itow'].keys())

    # save training settings
    opt_json = os.path.join(opt["checkpoint_path"], 'opt_info.json')
    with open(opt_json, 'w') as f:
        json.dump(opt, f)
    print('save opt details to %s' % (opt_json))

    model = get_model(opt)
    print_information(opt, model)
    device = torch.device('cuda' if not opt['no_cuda'] else 'cpu')

    if opt.get('pretrained_path', ''):
        print('loading pretrained model from %s' % opt['pretrained_path'])
        # map_location lets a checkpoint saved on GPU be loaded on a CPU run
        checkpoint = torch.load(opt['pretrained_path'], map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

    train_network_all(opt, model, device)
Exemplo n.º 3
0
def run_experiment(spec,
                   monitor_path=None,
                   only_last=False,
                   description=None,
                   seed=None):
    """Run an experiment using a specification dictionary.

    Args:
        spec: Experiment specification. Reads ``spec["agent"]["name"]``,
            ``spec["agent"]["args"]`` (updated in place with the resolved
            paths, seed and environments), ``spec["environments"]["type"]``
            and ``spec["environments"]["source"]``.
        monitor_path: Output directory. When falsy, the value of
            ``spec["agent"]["args"]["monitor_path"]`` is used instead.
        only_last: When true, ``args["env"]`` is set to the last environment
            even if several environments were created.
        description: Optional text stored in the saved configuration.
        seed: Optional seed passed to ``set_seed`` and to every environment.

    Raises:
        ValueError: If the environments type is neither ``"single"`` nor
            ``"json"``.
    """

    import os

    if seed is not None:
        set_seed(seed)

    import datetime
    import gym
    gym.logger.set_level(gym.logger.ERROR)
    from gym.spaces import Discrete
    from environment.registration import make, make_environments
    from agents.registration import make_agent

    args = spec["agent"]["args"]
    if monitor_path:
        args["monitor_path"] = monitor_path
    else:
        monitor_path = args["monitor_path"]
    # Build the config path only after monitor_path is resolved; joining
    # with the default ``None`` would raise a TypeError.
    args["config_path"] = os.path.join(monitor_path, "config.json")
    os.makedirs(monitor_path, exist_ok=True)

    envs_type = spec["environments"]["type"]
    if envs_type == "single":
        envs = [make(spec["environments"]["source"])]
    elif envs_type == "json":
        envs = make_environments(json_to_dict(spec["environments"]["source"]))
    else:
        raise ValueError(
            "Unknown environments type: {!r} (expected 'single' or 'json')"
            .format(envs_type))

    if seed is not None:
        for env in envs:
            env.seed(seed)
    args["seed"] = seed
    args["envs"] = envs
    if len(envs) == 1 or only_last:
        args["env"] = envs[-1]

    # The agent variant is selected from the first environment's spaces.
    first_env = envs[0]
    if isinstance(first_env.action_space, Discrete):
        action_space_type = "discrete"
    else:
        action_space_type = "continuous"
    if len(first_env.observation_space.shape) == 1:
        state_dimensions = "single"
    else:
        state_dimensions = "multi"

    agent = make_agent(spec["agent"]["name"], state_dimensions,
                       action_space_type, **args)
    config = agent.config.copy()
    if description is not None:
        config["description"] = description
    config["seed"] = str(seed)
    config["start_time"] = datetime.datetime.now().astimezone().isoformat()
    save_config(monitor_path, config,
                [env.metadata["parameters"] for env in envs])
    agent.learn()
Exemplo n.º 4
0
def main(opt):
    """Resolve paths in ``opt``, build the model and start training.

    The option dictionary is updated in place (seed, checkpoint folder,
    feature / corpus paths, ``vocab_size``) and dumped to ``opt_info.json``
    in the checkpoint folder. Depending on ``use_beam_decoder`` either only
    the beam decoder or the whole network is trained.

    Args:
        opt (dict): Training options. Optional keys read with ``get`` are
            ``seed``, ``knowledge_distillation_with_bert`` and
            ``use_beam_decoder``; all other keys used below are required.
    """
    if opt.get('seed', -1) == -1:
        opt['seed'] = random.randint(1, 65534)
    utils.set_seed(opt['seed'])
    print('SEEEEEEEEEEED: %d' % opt['seed'])

    model_name = opt['encoder_type'] + '_' + opt['decoder_type']

    if opt['na'] or opt['ar']:
        scope = get_scope2(opt)
    else:
        scope = get_scope(opt)

    opt["checkpoint_path"] = os.path.join(opt["checkpoint_path"],
                                          opt['checkpoint_path_name'],
                                          opt['dataset'], model_name, scope)

    opt['feats_a'] = get_dir(opt, 'feats_a_name', '/feats/')
    opt['feats_m'] = get_dir(opt, 'feats_m_name', '/feats/', pre=True)
    opt['feats_i'] = get_dir(opt, 'feats_i_name', '/feats/', pre=True)
    opt['feats_s'] = get_dir(opt, 'feats_s_name', '/feats/', pre=True)
    opt['feats_t'] = get_dir(opt, 'feats_t_name', '/feats/')
    opt['reference'] = get_dir(opt, 'reference_name', post='.pkl', pre=True)

    if opt.get('knowledge_distillation_with_bert', False):
        opt['bert_embeddings'] = get_dir(opt,
                                         'bert_embeddings_name',
                                         '/feats/',
                                         pre=True)

    # The corpus info file name carries the word-count threshold and, when
    # 'dist' is set, an extra '_<dist>' suffix.
    dist_suffix = '_%d' % opt['dist'] if opt['dist'] else ''
    opt['info_corpus'] = get_dir(opt,
                                 'info_corpus_name',
                                 post='_%d%s.pkl' %
                                 (opt['word_count_threshold'], dist_suffix),
                                 prefix=opt['prefix'])

    opt['corpus_pickle'] = get_dir(opt, 'corpus_name', post='.pkl')

    # Read the vocabulary size; the file is closed as soon as it is parsed.
    # NOTE: pickle must only be used on trusted corpus files.
    vocab_source = opt['corpus_pickle'] if opt['others'] else opt['info_corpus']
    with open(vocab_source, 'rb') as f:
        opt['vocab_size'] = len(pickle.load(f)['info']['itow'].keys())

    # save training settings
    opt_json = os.path.join(opt["checkpoint_path"], 'opt_info.json')
    # exist_ok avoids the race between checking for and creating the folder
    os.makedirs(opt["checkpoint_path"], exist_ok=True)
    with open(opt_json, 'w') as f:
        json.dump(opt, f)
    print('save opt details to %s' % (opt_json))

    model = get_model(opt)
    device = torch.device('cuda' if not opt['no_cuda'] else 'cpu')

    if opt.get('use_beam_decoder', False):
        assert opt['load_pretrained']
        # map_location lets a checkpoint saved on GPU be loaded on a CPU run
        checkpoint = torch.load(opt['load_pretrained'],
                                map_location=device)['state_dict']

        # strict=False: the checkpoint and the current network are not
        # required to have identical parameter names
        model.load_state_dict(checkpoint, strict=False)

        # we only train beam decoder
        for name, parameter in model.named_parameters():
            if 'beam' not in name:
                parameter.requires_grad = False

        print_information(opt, model, model_name)
        train_beam_decoder(
            opt,
            model,
            device,
            first_evaluate_whole_folder=opt['first_evaluate_whole_folder'])
    else:
        if opt['load_pretrained']:
            model.load_state_dict(
                torch.load(opt['load_pretrained'],
                           map_location=device)['state_dict'])

        print_information(opt, model, model_name)
        train_network_all(
            opt,
            model,
            device,
            first_evaluate_whole_folder=opt['first_evaluate_whole_folder'])
Exemplo n.º 5
0
def train(opt):
    """Train an ``EncoderDecoder`` model with periodic validation.

    Sets up logging and data loaders, optionally resumes from a previous run
    (``opt.start_from``) or loads pre-trained weights (``opt.pretrain``),
    then alternates training epochs with evaluation / checkpointing until
    ``opt.epoch`` epochs have elapsed.

    Args:
        opt: Options object whose attributes are read and updated through
            ``vars(opt)`` (e.g. ``current_lr``, ``ss_prob``).

    Returns:
        dict: ``saved_info`` holding the 'best', 'last' and 'history'
        entries that are also written to ``info.json`` in the save folder.
    """
    set_seed(opt.seed)
    save_folder = build_floder(opt)
    logger = create_logger(save_folder, 'train.log')
    tf_writer = SummaryWriter(os.path.join(save_folder, 'tf_summary'))

    if not opt.start_from:
        backup_envir(save_folder)
        logger.info('backup evironment completed !')

    saved_info = {'best': {}, 'last': {}, 'history': {}, 'eval_history': {}}

    # continue training
    if opt.start_from:
        opt.pretrain = False
        infos_path = os.path.join(save_folder, 'info.json')
        with open(infos_path) as f:
            logger.info('Load info from {}'.format(infos_path))
            saved_info = json.load(f)
            prev_opt = saved_info[opt.start_from_mode[:4]]['opt']

            # Restore every saved option except the ones that control the
            # resume itself; differences in those are only logged.
            exclude_opt = ['start_from', 'start_from_mode', 'pretrain']
            for opt_name in prev_opt.keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: prev_opt.get(opt_name)})
                if prev_opt.get(opt_name) != vars(opt).get(opt_name):
                    logger.info('Change opt {} : {} --> {}'.format(
                        opt_name, prev_opt.get(opt_name),
                        vars(opt).get(opt_name)))
        opt.feature_dim = opt.raw_feature_dim

    train_dataset = PropSeqDataset(opt.train_caption_file,
                                   opt.visual_feature_folder, True,
                                   opt.train_proposal_type, logger, opt)

    val_dataset = PropSeqDataset(opt.val_caption_file,
                                 opt.visual_feature_folder, False, 'gt',
                                 logger, opt)

    train_loader = DataLoader(train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.nthreads,
                              collate_fn=collate_fn)

    val_loader = DataLoader(val_dataset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.nthreads,
                            collate_fn=collate_fn)

    # Training state; the defaults apply when nothing was restored above.
    resume_info = saved_info[opt.start_from_mode[:4]]
    epoch = resume_info.get('epoch', 0)
    iteration = resume_info.get('iter', 0)
    best_val_score = resume_info.get('best_val_score', -1e5)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    opt.current_lr = vars(opt).get('current_lr', opt.lr)

    # Build model
    model = EncoderDecoder(opt)
    model.train()

    # Without CUDA, checkpoints saved on GPU must be remapped to the CPU.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')

    # Recover the parameters
    if opt.start_from and (not opt.pretrain):
        if opt.start_from_mode == 'best':
            # The best checkpoint is saved below as 'model-best.pth';
            # 'model-best-CE.pth' is used when it is present.
            best_pth = os.path.join(save_folder, 'model-best-CE.pth')
            if not os.path.exists(best_pth):
                best_pth = os.path.join(save_folder, 'model-best.pth')
            model_pth = torch.load(best_pth, map_location=map_location)
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(os.path.join(save_folder, 'model-last.pth'),
                                   map_location=map_location)
        logger.info('Loading pth from {}, iteration:{}'.format(
            save_folder, iteration))
        model.load_state_dict(model_pth['model'])

    # Load the pre-trained model
    if opt.pretrain and (not opt.start_from):
        logger.info('Load pre-trained parameters from {}'.format(
            opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path, map_location=map_location)
        model.load_state_dict(model_pth['model'])

    if torch.cuda.is_available():
        model.cuda()

    if opt.optimizer_type == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.lr,
                              weight_decay=opt.weight_decay)

    if opt.start_from:
        optimizer.load_state_dict(model_pth['optimizer'])

    # print the args for debugging
    print_opt(opt, model, logger)
    print_alert_message('Strat training !', logger)

    loss_sum = np.zeros(3)
    bad_video_num = 0
    start = time.time()

    # Log five times per epoch; at least every iteration so that loaders
    # with fewer than five batches do not cause a modulo-by-zero.
    losses_log_every = max(1, len(train_loader) // 5)

    # Epoch-level iteration
    while True:
        # lr decay
        if epoch > opt.learning_rate_decay_start >= 0:
            frac = (epoch - opt.learning_rate_decay_start
                    ) // opt.learning_rate_decay_every
            decay_factor = opt.learning_rate_decay_rate**frac
            opt.current_lr = opt.lr * decay_factor
        else:
            opt.current_lr = opt.lr
        utils.set_lr(optimizer, opt.current_lr)

        # scheduled sampling rate update
        if epoch > opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start
                    ) // opt.scheduled_sampling_increase_every
            opt.ss_prob = min(
                opt.basic_ss_prob +
                opt.scheduled_sampling_increase_prob * frac,
                opt.scheduled_sampling_max_prob)
            model.decoder.ss_prob = opt.ss_prob

        # Batch-level iteration
        for dt in tqdm(train_loader):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            if opt.debug:
                # each epoch contains less mini-batches for debugging
                if (iteration + 1) % 5 == 0:
                    iteration += 1
                    break
            elif epoch == 0:
                # NOTE(review): outside debug mode this skips every batch of
                # epoch 0, so no optimisation step runs in that epoch --
                # confirm this is intended.
                break
            iteration += 1

            # Gradients must be cleared on every step, on CPU as well as GPU;
            # otherwise they accumulate across iterations.
            optimizer.zero_grad()
            if torch.cuda.is_available():
                dt = {
                    key: _.cuda() if isinstance(_, torch.Tensor) else _
                    for key, _ in dt.items()
                }

            # Missing batch fields read as None instead of raising KeyError.
            dt = collections.defaultdict(lambda: None, dt)

            train_mode = 'train'

            loss = model(dt, mode=train_mode)
            loss_sum[0] = loss_sum[0] + loss.item()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            if iteration % losses_log_every == 0:
                end = time.time()
                losses = np.round(loss_sum / losses_log_every, 3)
                logger.info(
                    "ID {} iter {} (epoch {}, lr {}), avg_iter_loss = {}, time/iter = {:.3f}, bad_vid = {:.3f}"
                    .format(opt.id, iteration, epoch, opt.current_lr, losses,
                            (end - start) / losses_log_every, bad_video_num))

                tf_writer.add_scalar('lr', opt.current_lr, iteration)
                tf_writer.add_scalar('ss_prob', model.decoder.ss_prob,
                                     iteration)
                tf_writer.add_scalar('train_caption_loss', losses[0].item(),
                                     iteration)

                loss_history[iteration] = losses.tolist()
                lr_history[iteration] = opt.current_lr
                loss_sum = 0 * loss_sum
                start = time.time()
                bad_video_num = 0
                torch.cuda.empty_cache()

        # evaluation
        if (epoch % opt.save_checkpoint_every
                == 0) and (epoch >= opt.min_epoch_when_save) and (epoch != 0):
            model.eval()

            result_json_path = os.path.join(
                save_folder, 'prediction',
                'num{}_epoch{}_score{}_nms{}_top{}.json'.format(
                    len(val_dataset), epoch, opt.eval_score_threshold,
                    opt.eval_nms_threshold, opt.eval_top_n))
            eval_score = evaluate(model,
                                  val_loader,
                                  result_json_path,
                                  opt.eval_score_threshold,
                                  opt.eval_nms_threshold,
                                  opt.eval_top_n,
                                  False,
                                  1,
                                  logger=logger)
            current_score = np.array(eval_score['f1']).mean()

            # add to tf summary
            for key in eval_score.keys():
                tf_writer.add_scalar(key,
                                     np.array(eval_score[key]).mean(),
                                     iteration)
            # Append the mean to every list-valued metric so that it is
            # included in the printed and saved results.
            for item in eval_score.values():
                if isinstance(item, list):
                    item.append(np.array(item).mean())
            print_info = '\n'.join([
                key + ":" + str(eval_score[key]) for key in eval_score.keys()
            ])
            logger.info(
                '\nValidation results of iter {}:\n'.format(iteration) +
                print_info)
            val_result_history[epoch] = {'eval_score': eval_score}

            # Save model
            saved_pth = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(
                    save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model-last.pth')

            torch.save(saved_pth, checkpoint_path)
            logger.info('Save model at iter {} to {}.'.format(
                iteration, checkpoint_path))

            # save the model parameters and info of the best epoch
            if current_score > best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                saved_info['best'] = {
                    'opt': vars(opt),
                    'iter': iteration,
                    'epoch': best_epoch,
                    'best_val_score': best_val_score,
                    'result_json_path': result_json_path,
                    'avg_proposal_num': eval_score['avg_proposal_number'],
                    'Precision': eval_score['Precision'],
                    'Recall': eval_score['Recall']
                }

                torch.save(saved_pth,
                           os.path.join(save_folder, 'model-best.pth'))
                logger.info(
                    'Save Best-model at iter {} to checkpoint file.'.format(
                        iteration))

            saved_info['last'] = {
                'opt': vars(opt),
                'iter': iteration,
                'epoch': epoch,
                'best_val_score': best_val_score,
            }
            saved_info['history'] = {
                'val_result_history': val_result_history,
                'loss_history': loss_history,
                'lr_history': lr_history,
            }
            with open(os.path.join(save_folder, 'info.json'), 'w') as f:
                json.dump(saved_info, f)
            logger.info('Save info to info.json')

            model.train()

        epoch += 1
        torch.cuda.empty_cache()
        # Stop criterion
        if epoch >= opt.epoch:
            tf_writer.close()
            break

    return saved_info
Exemplo n.º 6
0
def train(opt):
    """Set up data, model and optimizer, then enter the training loop.

    Optionally resumes from a previous run (``opt.start_from``) or loads
    pre-trained weights (``opt.pretrain``).

    NOTE(review): only the learning-rate update at the top of the epoch loop
    is visible in this snippet; no batch loop or exit condition is shown.
    """
    set_seed(opt.seed)
    save_folder = build_floder(opt)  # e.g. './save/debug_2020-10-26_08-53-55'; creates the results folder
    logger = create_logger(save_folder, 'train.log')   # create the logger object
    tf_writer = SummaryWriter(os.path.join(save_folder, 'tf_summary'))   # tensorboardX

    if not opt.start_from:
        backup_envir(save_folder)   # back up the environment into the results folder
        logger.info('backup evironment completed !')

    saved_info = {'best': {}, 'last': {}, 'history': {}, 'eval_history': {}}

    # continue training
    if opt.start_from:
        opt.pretrain = False
        infos_path = os.path.join(save_folder, 'info.json')
        with open(infos_path) as f:
            logger.info('Load info from {}'.format(infos_path))
            saved_info = json.load(f)
            prev_opt = saved_info[opt.start_from_mode[:4]]['opt']

            # Restore every saved option except the ones that control the
            # resume itself; differences in those are only logged.
            exclude_opt = ['start_from', 'start_from_mode', 'pretrain']
            for opt_name in prev_opt.keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: prev_opt.get(opt_name)})
                if prev_opt.get(opt_name) != vars(opt).get(opt_name):
                    logger.info('Change opt {} : {} --> {}'.format(opt_name, prev_opt.get(opt_name),
                                                                   vars(opt).get(opt_name)))
        opt.feature_dim = opt.raw_feature_dim

    train_dataset = PropSeqDataset(opt.train_caption_file,
                                   opt.visual_feature_folder,
                                   opt.dict_file, True, opt.train_proposal_type,
                                   logger, opt)

    val_dataset = PropSeqDataset(opt.val_caption_file,
                                 opt.visual_feature_folder,
                                 opt.dict_file, False, 'gt',
                                 logger, opt)

    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size,
                              shuffle=True, num_workers=opt.nthreads, collate_fn=collate_fn)

    val_loader = DataLoader(val_dataset, batch_size=opt.batch_size,
                            shuffle=False, num_workers=opt.nthreads, collate_fn=collate_fn)

    # Training state; the defaults apply when nothing was restored above.
    epoch = saved_info[opt.start_from_mode[:4]].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode[:4]].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode[:4]].get('best_val_score', -1e5)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    opt.current_lr = vars(opt).get('current_lr', opt.lr)
    # The vocabulary size comes from the training dataset.
    opt.vocab_size = train_loader.dataset.vocab_size

    # Build model
    model = EncoderDecoder(opt)  # core of the pipeline
    model.train()

    # Recover the parameters
    if opt.start_from and (not opt.pretrain):  # example values: start_from = '', pretrain = False
        if opt.start_from_mode == 'best':
            model_pth = torch.load(os.path.join(save_folder, 'model-best-CE.pth'))
        elif opt.start_from_mode == 'best-RL':
            model_pth = torch.load(os.path.join(save_folder, 'model-best-RL.pth'))
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(os.path.join(save_folder, 'model-last.pth'))
        logger.info('Loading pth from {}, iteration:{}'.format(save_folder, iteration))
        model.load_state_dict(model_pth['model'])

    # Load the pre-trained model
    if opt.pretrain and (not opt.start_from):
        logger.info('Load pre-trained parameters from {}'.format(opt.pretrain_path))
        if torch.cuda.is_available():
            model_pth = torch.load(opt.pretrain_path)
        else:
            # Without CUDA, remap the checkpoint tensors to the CPU.
            model_pth = torch.load(opt.pretrain_path, map_location=torch.device('cpu'))
        model.load_state_dict(model_pth['model'])

    if torch.cuda.is_available():
        model.cuda()

    if opt.optimizer_type == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)  # example value: weight_decay = 0
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)

    if opt.start_from:
        optimizer.load_state_dict(model_pth['optimizer'])

    # print the args for debugging
    print_opt(opt, model, logger)
    print_alert_message('Strat training !', logger)

    loss_sum = np.zeros(3)  # shape (3,): loss, sample_score, greedy_score
    bad_video_num = 0
    start = time.time()

    # Epoch-level iteration
    while True:
        if True:
            # lr decay
            if epoch > opt.learning_rate_decay_start >= 0:  # example value: learning_rate_decay_start=8
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.lr * decay_factor
            else:
                opt.current_lr = opt.lr
            utils.set_lr(optimizer, opt.current_lr)
Exemplo n.º 7
0
def main(opt):
    """Resolve paths in ``opt``, build the model, print a summary and train.

    The option dictionary is updated in place (seed, checkpoint folder,
    feature / corpus paths, ``vocab_size``) and dumped to ``opt_info.json``
    in the checkpoint folder before ``train_network_all`` is called.

    Args:
        opt (dict): Training options. Keys read with ``get`` below are
            optional; all other keys used are required.
    """
    if opt.get('seed', -1) == -1:
        opt['seed'] = random.randint(1, 65534)
    utils.set_seed(opt['seed'])
    print('SEEEEEEEEEEED: %d' % opt['seed'])

    model_name = opt['encoder_type'] + '_' + opt['decoder_type']

    if opt['na'] or opt['ar']:
        scope = get_scope2(opt)
    else:
        scope = get_scope(opt)

    opt["checkpoint_path"] = os.path.join(opt["checkpoint_path"],
                                          opt['checkpoint_path_name'],
                                          opt['dataset'], model_name, scope)

    opt['feats_a'] = get_dir(opt, 'feats_a_name', '/feats/')
    opt['feats_m'] = get_dir(opt, 'feats_m_name', '/feats/', pre=True)
    opt['feats_i'] = get_dir(opt, 'feats_i_name', '/feats/', pre=True)
    opt['feats_s'] = get_dir(opt, 'feats_s_name', '/feats/', pre=True)
    opt['feats_t'] = get_dir(opt, 'feats_t_name', '/feats/')
    opt['reference'] = get_dir(opt, 'reference_name', post='.pkl', pre=True)

    if opt.get('knowledge_distillation_with_bert', False):
        opt['bert_embeddings'] = get_dir(opt,
                                         'bert_embeddings_name',
                                         '/feats/',
                                         pre=True)

    # The corpus info file name carries the word-count threshold and, when
    # 'dist' is set, an extra '_<dist>' suffix.
    dist_suffix = '_%d' % opt['dist'] if opt['dist'] else ''
    opt['info_corpus'] = get_dir(opt,
                                 'info_corpus_name',
                                 post='_%d%s.pkl' %
                                 (opt['word_count_threshold'], dist_suffix),
                                 prefix=opt['prefix'])

    opt['corpus_pickle'] = get_dir(opt, 'corpus_name', post='.pkl')

    # Read the vocabulary size; the file is closed as soon as it is parsed.
    # NOTE: pickle must only be used on trusted corpus files.
    vocab_source = opt['corpus_pickle'] if opt['others'] else opt['info_corpus']
    with open(vocab_source, 'rb') as f:
        opt['vocab_size'] = len(pickle.load(f)['info']['itow'].keys())

    # save training settings
    opt_json = os.path.join(opt["checkpoint_path"], 'opt_info.json')
    # exist_ok avoids the race between checking for and creating the folder
    os.makedirs(opt["checkpoint_path"], exist_ok=True)
    with open(opt_json, 'w') as f:
        json.dump(opt, f)
    print('save opt details to %s' % (opt_json))

    model = get_model(opt)

    # Summary of the model and of the main training settings
    print(model)
    print('| model {}'.format(model_name))
    print('| num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))
    print('use trigger: %d' % opt.get('use_trigger', 0))
    print('trigger level: %g' % opt.get('trigger_level', 0.25))
    print('dataloader random type: %s' %
          opt.get('random_type', 'segment_random'))
    print('k best model: %d' % opt.get('k_best_model', 10))
    print('teacher prob: %g' % opt.get('teacher_prob', 1.0))
    print('save model limit: %d' % opt.get('save_model_limit', -1))
    print('modality: %s' % opt.get('modality', 'ic'))
    print('equally sampling: %s' % opt.get('equally_sampling', False))
    print('n frames: %d' % opt['n_frames'])
    print('start eval epoch: %d' % opt['start_eval_epoch'])
    print('save_checkpoint_every: %d' % opt['save_checkpoint_every'])
    print('max_len: %d' % opt['max_len'])
    print('scheduled_sampling: {}'.format(opt['scheduled_sampling']))
    print('vocab_size: %d' % opt['vocab_size'])

    device = torch.device('cuda' if not opt['no_cuda'] else 'cpu')

    train_network_all(
        opt,
        model,
        device,
        first_evaluate_whole_folder=opt['first_evaluate_whole_folder'])
Exemplo n.º 8
0
def main(opt):
    """Load a trained model and run one of the evaluation modes.

    The mode is chosen from ``opt`` in this order: ``plot``, ``loop``,
    ``category``, otherwise a single ``run_eval`` pass whose metric is
    printed. Unless ``opt.nt`` is set, a teacher model is loaded from
    ``opt.teacher_path`` and passed along.

    Args:
        opt: Namespace of command-line options; its values override the
            options stored with the loaded model via ``vars(opt)``.
    """
    if opt.collect:
        # exist_ok avoids the race between checking for and creating the
        # folder
        os.makedirs(opt.collect_path, exist_ok=True)

    device = torch.device('cuda' if not opt.no_cuda else 'cpu')

    model, option = load(opt.model_path,
                         opt.model_name,
                         device,
                         mid_path='best')
    option.update(vars(opt))
    set_seed(option['seed'])

    dict_mapping = {}
    if not opt.nt:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU run
        checkpoint = torch.load(opt.teacher_path, map_location=device)
        teacher_option = checkpoint['settings']
        teacher_model = get_model(teacher_option)
        teacher_model.load_state_dict(checkpoint['state_dict'])
        teacher_model.to(device)

        # Teacher and student must share the same vocabulary size.
        assert teacher_option['vocab_size'] == option['vocab_size']
    else:
        teacher_model = None

    if opt.plot:
        # The loader and vocabulary have to be created here as well;
        # nothing defines them before this branch.
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        plot(option, opt, model, loader, vocab, device, teacher_model,
             dict_mapping)
    elif opt.loop:
        loader = get_loader(option, mode=opt.em, print_info=True)
        vocab = loader.dataset.get_vocab()
        loop_length_beam(option, opt, model, loader, vocab, device,
                         teacher_model, dict_mapping)
    elif opt.category:
        loop_category(option, opt, model, device, teacher_model, dict_mapping)
    else:
        loader = get_loader(option,
                            mode=opt.em,
                            print_info=True,
                            specific=opt.specific)
        vocab = loader.dataset.get_vocab()
        # File name for the collected results, built from the dataset,
        # method, paradigm and decoding settings.
        paradigm_tag = ('AE' if opt.nv_scale == 100 else '') + opt.paradigm
        filename = '%s_%s_%s_i%db%da%03d%s.pkl' % (
            option['dataset'], option['method'], paradigm_tag,
            opt.iterations, opt.length_beam_size, int(
                100 * opt.beam_alpha), '_all' if opt.em == 'all' else '')
        metric = run_eval(option,
                          model,
                          None,
                          loader,
                          vocab,
                          device,
                          json_path=opt.json_path,
                          json_name=opt.json_name,
                          print_sent=opt.print_sent,
                          teacher_model=teacher_model,
                          length_crit=torch.nn.SmoothL1Loss(),
                          dict_mapping=dict_mapping,
                          analyze=opt.analyze,
                          collect_best_candidate_iterative_results=bool(
                              opt.collect),
                          collect_path=os.path.join(opt.collect_path,
                                                    filename),
                          no_score=opt.ns,
                          write_time=opt.write_time)
        print(metric)