Example #1
def checkpoint_stage(stage, result):
    checkpoint = utils.get_checkpoint().get(stage, {})
    checkpoint.update({'result': result})
    utils.update_checkpoint({
        'stage': stages.index(stage),
        stage: checkpoint,
        })
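
Examples #1-#5 rely on two helpers, utils.get_checkpoint and utils.update_checkpoint, that are not shown. A minimal sketch of how such helpers could look, assuming the checkpoint is kept in a JSON file on disk (the checkpoint.json path is an assumption, not the project's actual layout):

import json
import os

CHECKPOINT_FILE = 'checkpoint.json'  # hypothetical location

def get_checkpoint():
    # Return the stored checkpoint dict, or an empty dict if none exists yet
    if not os.path.exists(CHECKPOINT_FILE):
        return {}
    with open(CHECKPOINT_FILE) as f:
        return json.load(f)

def update_checkpoint(values):
    # Merge the given keys into the stored checkpoint and write it back
    checkpoint = get_checkpoint()
    checkpoint.update(values)
    with open(CHECKPOINT_FILE, 'w') as f:
        json.dump(checkpoint, f, indent=2)
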
Example #2
def set_manual_stages(params):
    start_stage_idx = len(stages) - 1
    checkpoint = utils.get_checkpoint()['params']

    # For each stage, check if any of the parameters have changed since the last run
    changed_stages = set()
    for stage, stage_params in params.items():
        diff = {}
        stage_checkpoint = checkpoint.get(stage, {})
        for key, value in stage_params.items():
            if stage_checkpoint and stage_checkpoint.get(key) != value:
                # the value differs from the checkpoint
                diff[key] = (value, stage_checkpoint.get(key))

        if diff:
            logger.info("parameters differ from checkpoint in stage %s: %s", stage, diff)
            changed_stages.add(stage)

    # For each stage, mark it as changed if any of its dependencies changed
    for stage in stages:
        for dependency in dependencies.get(stage, []):
            if stages[dependency] in changed_stages:
                changed_stages.add(stage)
    
    # Finally, add the required stages to the manual stages
    if changed_stages:
        start_stage_idx = min([ stages.index(stage) for stage in changed_stages ])
    stage_required.manual_stages = range(start_stage_idx, len(stages))
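
For reference, set_manual_stages assumes stages is an ordered list of stage names, params maps each stage name to its parameter dict, and the values in dependencies index into stages (see the stages[dependency] lookup above). The shapes below are purely illustrative, with invented values:

# Purely illustrative; the real stage names and parameters are defined elsewhere.
stages = ['sieve', 'linalg', 'trial_division']          # ordered stage names
dependencies = {'linalg': [0], 'trial_division': [1]}   # values index into `stages`
params = {'sieve': {'qmin': 100000}, 'linalg': {'threads': 4}}
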
Example #3
def stage_required(stage):
    # required unless this stage (or a later one) has already been checkpointed
    stage_index = stages.index(stage)
    required = (utils.get_checkpoint()['stage'] < stage_index)

    if stage_required.manual_stages:
        if stage_index > max(stage_required.manual_stages):
            sys.exit(0)
        elif stage_index in stage_required.manual_stages:
            return True
        else:
            return False
    return required
Example #4
def do_linalg(parameters, sieve_result):
    result = {}
    # Check the rare case that factorization was completed in msieve's filtering
    if utils.get_checkpoint().get('trial_division') is not None:
        result['duration'] = 0
    elif stage_required('linalg'):
        linalg_start = time.time()
        linalg.run(parameters)
        linalg_finish = time.time()
        result['duration'] = linalg_finish - linalg_start
        logger.info("\tLinalg in %s", utils.str_time(result['duration']))

        checkpoint_stage('linalg', result)
    else:
        result = load_stage('linalg')

    post_linalg = parameters.myparams({'post_linalg': None}, ['commands']).get('post_linalg')
    if post_linalg is not None:
        logger.info('Post-linalg command %s', post_linalg)
        utils.run_command(post_linalg, logger=logger)

    return result
Example #5
def load_stage(stage):
    return utils.get_checkpoint().get(stage, {}).get('result')
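
Taken together, these helpers give every pipeline stage the same run-or-reuse shape that do_linalg follows in Example #4. A minimal generic sketch of that pattern (run_stage is a hypothetical callable; time is assumed to be imported as in Example #4):

def run_or_load(stage, run_stage, parameters):
    # Re-run the stage only when the checkpoint says it is still required
    if stage_required(stage):
        start = time.time()
        run_stage(parameters)
        result = {'duration': time.time() - start}
        checkpoint_stage(stage, result)
        return result
    # Otherwise reuse the result recorded by a previous run
    return load_stage(stage)
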
Example #6
def main():

    # get args
    parser = argparse.ArgumentParser(description="Im2Latex Training Program")
    # parser.add_argument('--path', required=True, help='root of the model')

    # model args
    parser.add_argument("--emb_dim",
                        type=int,
                        default=80,
                        help="Embedding size")
    parser.add_argument("--dec_rnn_h",
                        type=int,
                        default=512,
                        help="The hidden state of the decoder RNN")
    parser.add_argument("--data_path",
                        type=str,
                        default="/root/private/im2latex/data/",
                        help="The dataset's dir")
    parser.add_argument("--add_position_features",
                        action='store_true',
                        default=False,
                        help="Use position embeddings or not")
    # training args
    parser.add_argument("--max_len",
                        type=int,
                        default=150,
                        help="Max size of formula")
    parser.add_argument("--dropout",
                        type=float,
                        default=0.,
                        help="Dropout probility")
    parser.add_argument("--cuda",
                        action='store_true',
                        default=True,
                        help="Use cuda or not")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--epoches", type=int, default=200)
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning Rate")
    parser.add_argument("--min_lr",
                        type=float,
                        default=3e-5,
                        help="Learning Rate")
    parser.add_argument("--sample_method",
                        type=str,
                        default="teacher_forcing",
                        choices=('teacher_forcing', 'exp', 'inv_sigmoid'),
                        help="The method to schedule sampling")
    parser.add_argument("--decay_k", type=float, default=1.)

    parser.add_argument("--lr_decay",
                        type=float,
                        default=0.5,
                        help="Learning Rate Decay Rate")
    parser.add_argument("--lr_patience",
                        type=int,
                        default=3,
                        help="Learning Rate Decay Patience")
    parser.add_argument("--clip",
                        type=float,
                        default=2.0,
                        help="The max gradient norm")
    parser.add_argument("--save_dir",
                        type=str,
                        default="./ckpts",
                        help="The dir to save checkpoints")
    parser.add_argument("--print_freq",
                        type=int,
                        default=100,
                        help="The frequency to print message")
    parser.add_argument("--seed",
                        type=int,
                        default=2020,
                        help="The random seed for reproducing ")
    parser.add_argument("--from_check_point",
                        action='store_true',
                        default=False,
                        help="Training from checkpoint or not")
    parser.add_argument("--batch_size_per_gpu", type=int, default=16)
    parser.add_argument("--gpu_num", type=int, default=4)
    device_ids = [0, 1, 2, 3]

    args = parser.parse_args()
    max_epoch = args.epoches
    from_check_point = args.from_check_point
    if from_check_point:
        checkpoint_path = get_checkpoint(args.save_dir)
        checkpoint = torch.load(checkpoint_path)
        args = checkpoint['args']
    print("Training args:", args)

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # Building vocab
    print("Load vocab...")
    vocab = load_vocab(args.data_path)

    use_cuda = args.cuda and torch.cuda.is_available()
    print("Use CUDA:", use_cuda)
    device = torch.device("cuda" if use_cuda else "cpu")

    # data loader
    print("Construct data loader...")
    # train_loader = DataLoader(
    #     Im2LatexDataset(args.data_path, 'train', args.max_len),
    #     batch_size=args.batch_size,
    #     collate_fn=partial(collate_fn, vocab.token2idx),
    #     pin_memory=True if use_cuda else False,
    #     num_workers=4)
    train_loader = DataLoader(
        Im2LatexDataset(args.data_path, 'train', args.max_len),
        batch_size=args.batch_size_per_gpu * args.gpu_num,
        collate_fn=partial(collate_fn, vocab.token2idx),
        pin_memory=True if use_cuda else False,
        num_workers=2)
    # val_loader = DataLoader(
    #     Im2LatexDataset(args.data_path, 'validate', args.max_len),
    #     batch_size=args.batch_size,
    #     collate_fn=partial(collate_fn, vocab.token2idx),
    #     pin_memory=True if use_cuda else False,
    #     num_workers=4)
    val_loader = DataLoader(Im2LatexDataset(args.data_path, 'validate',
                                            args.max_len),
                            batch_size=args.batch_size_per_gpu * args.gpu_num,
                            collate_fn=partial(collate_fn, vocab.token2idx),
                            pin_memory=True if use_cuda else False,
                            num_workers=2)

    # construct model
    print("Construct model")
    vocab_size = len(vocab)
    model = Im2LatexModel(vocab_size,
                          args.emb_dim,
                          args.dec_rnn_h,
                          add_pos_feat=args.add_position_features,
                          dropout=args.dropout)
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    print("Model Settings:")
    print(model)

    # construct optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     "min",
                                     factor=args.lr_decay,
                                     patience=args.lr_patience,
                                     verbose=True,
                                     min_lr=args.min_lr)

    if from_check_point:
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        lr_scheduler.load_state_dict(checkpoint['lr_sche'])
        # init trainer from checkpoint
        trainer = Trainer(optimizer,
                          model,
                          lr_scheduler,
                          train_loader,
                          val_loader,
                          args,
                          use_cuda=use_cuda,
                          init_epoch=epoch,
                          last_epoch=max_epoch)
    else:
        trainer = Trainer(optimizer,
                          model,
                          lr_scheduler,
                          train_loader,
                          val_loader,
                          args,
                          use_cuda=use_cuda,
                          init_epoch=1,
                          last_epoch=args.epoches)
    # begin training
    trainer.train()
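
When --from_check_point is set, main() expects torch.load to return a dict with the keys args, model_state_dict, optimizer_state_dict, lr_sche and epoch. A minimal sketch of a save routine producing that format (the helper name and the ckpt-<epoch>.pt file naming are assumptions; os and torch are assumed to be imported in this module):

def save_checkpoint(save_dir, model, optimizer, lr_scheduler, epoch, args):
    # Write a checkpoint with the keys the resume path above expects
    state = {
        'args': args,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'lr_sche': lr_scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(state, os.path.join(save_dir, "ckpt-{}.pt".format(epoch)))
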
Example #7
def run_rnn_rcnn(args):
    dataset_dir = utils.choose_data(args)
    with tf.device('/cpu:0'):
        dataset = utils.load_dataset(dataset_dir)
    print "Using checkpoint directory: {0}".format(args.ckpt_dir)
    reward_output_file = os.path.join(args.ckpt_dir, "rewards")
    accuracy_output_file = os.path.join(args.ckpt_dir, "accuracies")
    
    model = utils.choose_model(args) # pass in necessary model parameters
    print "Running {0} model for {1} epochs.".format(args.model, args.num_epochs)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    saver = tf.train.Saver(max_to_keep=args.num_epochs)

    with tf.Session(config=GPU_CONFIG) as session:
        print "Inititialized TF Session!"

        # Checkpoint
        i_stopped, found_ckpt = utils.get_checkpoint(args, session, saver)
        # Summary Writer
        file_writer = tf.summary.FileWriter(args.ckpt_dir, graph=session.graph, max_queue=10, flush_secs=30)
        # Val or test set accuracies

        # Initialize variables when training from scratch; otherwise a checkpoint is required
        if args.train == "train" and not found_ckpt:
            init_op = tf.global_variables_initializer()
            init_op.run()
        else:
            if not found_ckpt:
                print "No checkpoint found for test or validation!"
                return

        if "pretrained" in args.model:
            init_fn = tf.contrib.framework.assign_from_checkpoint_fn(
                                model_path='/data/yolo/YOLO_small.ckpt',
                                var_list=model.variables_to_restore)
            init_fn(session)

        if "neg_l1" in args.model:
            model.add_loss_op('negative_l1_dist')
        elif "iou" in args.model:
            model.add_loss_op('iou')
            model.add_summary_op()

        # if "pretrained" in args.model:

        if args.train == 'train':
            for i in xrange(i_stopped, args.num_epochs):
                print "Running epoch ({0})...".format(i)

                with tf.device('/cpu:0'):
                    batched_data, batched_labels, batched_seq_lens, batched_bbox = utils.make_batches(dataset, batch_size=BATCH_SIZE)

                rewards, accuracies = run_epoch(args, model, session, batched_data, batched_labels,
                                                batched_seq_lens, batched_bbox, saver, file_writer, i)
                with open(reward_output_file, "a") as reward_file:
                    reward_file.write("\n".join([r.astype('|S6') for r in rewards]) + "\n")
                with open(accuracy_output_file, "a") as accuracy_file:
                    accuracy_file.write("\n".join([a.astype('|S6') for a in accuracies]) + "\n")
        if args.train == 'test':
            # model.add_error_op()
            # model.add_summary_op()
            batched_data, batched_labels, batched_seq_lens, batched_bbox = utils.make_batches(dataset, batch_size=BATCH_SIZE)
            run_epoch(args, model, session, batched_data, batched_labels,
                      batched_seq_lens, batched_bbox, saver, file_writer, 1)
    print "Output weights are same: " + str(output)
    print "Output bias weights are same: " + str(output_bias)
    print "-------------------------------------------------------------------"


Example #8
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--nature", default=True, type=bool)
    args = parser.parse_args()

    epoch = 0
    everywhere_all_the_time = True
    while checkpoints_all_exist(epoch):
        nets = []
        for checkpoint_dir in CHECKPOINT_FOLDERS:
            checkpoint = utils.get_checkpoint(checkpoint_dir, epoch)
            nets.append(checkpoint['state_dict'])
        # Print results
        print "EPOCH: " + str(epoch)
        conv1 = True
        conv2 = True
        fc1 = True
        output = True
        conv1_bias = True
        conv2_bias = True
        fc1_bias = True
        output_bias = True
        if args.nature:
            conv3 = True
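
The final snippet is truncated before the comparison itself, but the boolean flags and the earlier "weights are same" prints suggest each layer's weights are checked for equality across the loaded networks. A minimal sketch of such a check, assuming PyTorch state dicts, torch imported, and hypothetical parameter names like 'conv1.weight':

def weights_equal(nets, param_name):
    # True if `param_name` holds an identical tensor in every loaded state_dict
    first = nets[0][param_name]
    return all(torch.equal(first, net[param_name]) for net in nets)

# e.g. conv1 = weights_equal(nets, 'conv1.weight')  # parameter name is an assumption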