Пример #1
0
def train_one_epoch(curr_iter, curr_epoch, train_dataloader, model, body_model,
                    prediction, optimizer, loss_function, smplify,
                    weakly_supervise, use_precal_smplify, logger, batch_size,
                    device, dtype):
    """Run one optimisation pass over ``train_dataloader``.

    Every batch is pushed through ``model`` and ``body_model``; the SMPLify
    parameters stored in the batch are pushed through ``body_model`` as well,
    and both outputs are handed to ``loss_function``. The resulting loss is
    back-propagated and ``optimizer`` takes one step per batch.

    Args:
        curr_iter: Global iteration counter on entry.
        curr_epoch: Accepted for interface compatibility; not used here.
        train_dataloader: Sized iterable of batches (mappings providing the
            ``smplify_*`` and ``has_smpl`` entries read below).
        model: Network called as ``model(images, proj_matricies, batch)``.
        body_model: Callable turning pose/shape parameters into its output.
        prediction: ``pose2rot`` is switched off only when this is 'rotmat'.
        optimizer: Optimizer stepped once per processed batch.
        loss_function: Callable returning ``(loss, loss_dict)``.
        smplify: Accepted for interface compatibility; not used here.
        weakly_supervise: When truthy, batches whose size differs from
            ``batch_size`` are skipped.
        use_precal_smplify: Accepted for interface compatibility; not used.
        logger: Callable ``logger(loss_dict, curr_iter)`` exposing
            ``write_freq``.
        batch_size: Expected number of samples per batch.
        device: Device for the batch tensors.
        dtype: dtype for the SMPLify tensors.

    Returns:
        The iteration counter after all processed batches.
    """
    batch_stream = iter(train_dataloader)
    pose2rot = not (prediction == 'rotmat')

    def _smplify_tensor(batch, key):
        # Move one pre-computed SMPLify array onto the training device.
        return torch.from_numpy(batch[key]).to(device=device, dtype=dtype)

    with torch.autograd.enable_grad(), \
            torch.autograd.set_detect_anomaly(True):
        for _ in range(len(train_dataloader)):
            batch = next(batch_stream)
            images, keypoints_3d_gt, proj_matricies = prepare_batch(
                batch, device)

            # A batch of unexpected size is not compatible with the SMPL
            # model batch size; it is only skipped when weakly supervising.
            if images.shape[0] != batch_size and weakly_supervise:
                continue

            network_out = model(images, proj_matricies, batch)
            pred_pose, pred_betas, pred_global_orient, _ = network_out

            pred_output = body_model(
                betas=pred_betas,
                body_pose=pred_pose,
                global_orient=pred_global_orient,
                pose2rot=pose2rot,
            )

            opt_pose = _smplify_tensor(batch, 'smplify_pose')
            opt_betas = _smplify_tensor(batch, 'smplify_shape')
            opt_global_orient = _smplify_tensor(batch, 'smplify_global_orient')

            opt_output = body_model(
                betas=opt_betas,
                body_pose=opt_pose,
                global_orient=opt_global_orient,
                pose2rot=True,
            )

            step_in_epoch = (curr_iter + 1) % len(train_dataloader)
            print_log = step_in_epoch % logger.write_freq == 0
            loss, loss_dict = loss_function(
                pred_output,
                keypoints_3d_gt,
                proj_matricies,
                opt_output=opt_output,
                print_log=print_log,
                has_smpl=batch['has_smpl'],
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            curr_iter = curr_iter + 1
            logger(loss_dict, curr_iter)

    return curr_iter
Пример #2
0
def decode():
    """Decode ``config['decode_input']`` with a restored model.

    The configuration comes from ``load_config(FLAGS)``. Predictions are
    converted to words with the target inverse dictionary and written to
    ``FLAGS.decode_output``; when ``FLAGS.write_n_best`` is set, one file
    ``<decode_output>_<k>`` is written per beam index ``k``.

    An ``IOError`` raised while opening or writing the output files is
    reported and ends decoding; every file opened so far is always closed.
    """
    # Load model config
    config = load_config(FLAGS)

    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'],
                            maxlen=None,
                            n_words_source=config['num_encoder_symbols'])

    # Load inverse dictionary used in decoding
    target_inverse_dict = data_utils.load_inverse_dict(
        config['target_vocabulary'])

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)

        # Bound before the ``try`` and filled one file at a time, so the
        # ``finally`` clause can close whatever was opened even when opening
        # a later file fails.
        fout = []
        try:
            print('Decoding {}..'.format(FLAGS.decode_input))
            if FLAGS.write_n_best:
                for k in range(FLAGS.beam_width):
                    fout.append(
                        data_utils.fopen(
                            "%s_%d" % (FLAGS.decode_output, k), 'w'))
            else:
                fout.append(data_utils.fopen(FLAGS.decode_output, 'w'))

            # NOTE(review): this enumerates the items of a single
            # ``test_set.next()`` result instead of iterating ``test_set``
            # itself, while the progress message below scales ``idx`` by the
            # batch size - confirm which iteration was intended.
            for idx, source_seq in enumerate(test_set.next()):
                print('Source', source_seq)
                source, source_len = prepare_batch(source_seq)
                print('Source', source, 'Source Len', source_len)
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                predicted_ids = model.predict(sess,
                                              encoder_inputs=source,
                                              encoder_inputs_length=source_len)
                print(predicted_ids)
                # Write decoding results, highest file index first
                for k, f in reversed(list(enumerate(fout))):
                    for seq in predicted_ids:
                        words = data_utils.seq2words(seq[:, k],
                                                     target_inverse_dict)
                        f.write(str(words) + '\n')
                    if not FLAGS.write_n_best:
                        break
                print('{}th line decoded'.format(idx *
                                                 FLAGS.decode_batch_size))

            print('Decoding terminated')
        except IOError as err:
            # Still best-effort, but no longer silent.
            print('Decoding aborted by I/O error: {}'.format(err))
        finally:
            for f in fout:
                f.close()
Пример #3
0
def decode(config):
    """Greedily decode ``config['decode_input']`` into ``config['decode_output']``.

    The model and the effective configuration come from
    ``load_model(config)``. Each source batch is encoded once, then decoded
    step by step for ``config['max_decode_step']`` steps, feeding the arg-max
    token of step ``t`` back as the input of step ``t + 1``. Every decoded
    sequence is converted with ``seq2words`` and written on its own line.

    Args:
        config: Mapping passed to ``load_model``; the mapping it returns
            provides the data, vocabulary and decoding settings read below.
    """
    model, config = load_model(config)
    # Load source data to decode
    test_set = TextIterator(
        source=config['decode_input'],
        source_dict=config['src_vocab'],
        batch_size=config['batch_size'],
        maxlen=None,
        n_words_source=config['num_enc_symbols'],
        shuffle_each_epoch=False,
        sort_by_length=False,
    )
    target_inv_dict = load_inv_dict(config['tgt_vocab'])

    lines = 0
    max_decode_step = config['max_decode_step']
    # Single-argument print(...) and range() behave identically on Python 2
    # and Python 3; the former print statements / xrange were Python-2-only.
    print('Decoding starts..')
    with fopen(config['decode_output'], 'w') as fout:
        for source_seq in test_set:
            source, source_len = prepare_batch(source_seq)
            batch_size = len(source)

            # Decoder inputs: column 0 holds the start token, later columns
            # are filled with the previous step's prediction.
            preds_prev = torch.zeros(batch_size, max_decode_step).long()
            preds_prev[:, 0] += data_utils.start_token
            preds = torch.zeros(batch_size, max_decode_step).long()

            if use_cuda:
                source = Variable(source.cuda())
                source_len = Variable(source_len.cuda())
                preds_prev = Variable(preds_prev.cuda())
                preds = preds.cuda()
            else:
                source = Variable(source)
                source_len = Variable(source_len)
                preds_prev = Variable(preds_prev)

            states, memories = model.encode(source, source_len)

            for t in range(max_decode_step):
                # logits: [batch_size x max_decode_step, tgt_vocab_size]
                _, logits = model.decode(preds_prev[:, :t + 1], states,
                                         memories)
                # outputs: [batch_size, max_decode_step]
                outputs = torch.max(logits, dim=1)[1].view(batch_size, -1)
                preds[:, t] = outputs[:, t].data
                if t < max_decode_step - 1:
                    preds_prev[:, t + 1] = outputs[:, t]

            for pred in preds:
                fout.write(str(seq2words(pred, target_inv_dict)) + '\n')
            # One flush per batch keeps the output observable while decoding
            # without paying for a flush on every line.
            fout.flush()

            lines += source.size(0)
            print('  {}th line decoded'.format(lines))
        print('Decoding terminated')
Пример #4
0
def predict(config):
    """Evaluate a restored ``Detector`` on ``config['valid']``.

    Args:
        config: Mapping providing 'valid', 'batch_size', 'source_vocabulary',
            'save_path', 'max_batch', 'maxlen' and 'stride'.

    Returns:
        A 3-tuple ``(prediction, binarized_labels, scores)`` where
        ``prediction`` stacks the logits returned by ``model.predict``,
        ``binarized_labels`` is ``label_binarize`` of the integer labels over
        classes ``[0, 1, 2]`` and ``scores`` is
        ``[accuracy, precision, recall]``.

    Raises:
        ValueError: If the validation set yields no samples.
    """
    tf.reset_default_graph()
    from data.data_iterator import TextIterator, Butian_TextIterator
    valid_set = TextIterator(source=config['valid'],
                             batch_size=config['batch_size'],
                             source_dict=config['source_vocabulary'])

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=tf.GPUOptions(
                                              allow_growth=True))) as sess:
        model = Detector(config, 'test')
        model.restore(sess, config['save_path'])

        # Sample-weighted sums; divided by ``_num`` when reported.
        _acc = 0
        _loss = 0
        _num = 0
        prediction = []
        all_labels = []
        all_pred = []
        for raw_batch in valid_set:
            sub_sources, sub_labels = prepare_batch(
                raw_batch[0],
                raw_batch[1],
                max_batch=config['max_batch'],
                maxlen=config['maxlen'],
                stride=config['stride'],
                batch_size=config['batch_size'])
            for source, label in zip(sub_sources, sub_labels):
                pred, logit, acc, loss = model.predict(sess, source, label)
                prediction.extend(logit)
                all_labels.extend(list(map(int, label)))
                all_pred.extend(list(map(int, pred)))
                _acc += acc * pred.shape[0]
                _loss += loss * pred.shape[0]
                _num += pred.shape[0]

        if _num == 0:
            # Previously surfaced as a bare ZeroDivisionError below.
            raise ValueError(
                'validation set {!r} produced no samples'.format(
                    config['valid']))

        print(config['save_path'])
        # Both metrics are reported per sample; the loss used to be printed
        # as the raw sum over all samples.
        print("step {}, acc {:g}, softmax_loss {:g}".format(
            model.global_step.eval(), _acc / _num, _loss / _num))
        prediction = np.stack(prediction)
        all_labels = np.stack(all_labels)
        all_pred = np.stack(all_pred)
        # NOTE(review): precision_score / recall_score are called with their
        # default averaging while labels are binarized over three classes -
        # confirm the labels seen here are binary, or pass ``average=``.
        return prediction, label_binarize(all_labels, classes=[0, 1, 2]), [
            accuracy_score(all_labels, all_pred),
            precision_score(all_labels, all_pred),
            recall_score(all_labels, all_pred)
        ]
Пример #5
0
def decode():
    """Run the restored model over ``config['decode_input']`` and dump raw ids.

    The configuration comes from ``load_config(FLAGS)``. For every batch the
    source sequences and the predicted id lists are written on one line of
    ``FLAGS.decode_output``; when ``FLAGS.write_n_best`` is set, one file
    ``<decode_output>_<k>`` is written per beam index ``k``.

    An ``IOError`` raised while opening or writing the output files is
    reported and ends decoding; every file opened so far is always closed.
    """
    # Load model config
    config = load_config(FLAGS)
    print(config['source_vocabulary'])
    # Load source data to decode
    test_set = TextIterator(source=config['decode_input'],
                            batch_size=config['decode_batch_size'],
                            source_dict=config['source_vocabulary'])

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)

        # Bound before the ``try`` and filled one file at a time, so the
        # ``finally`` clause can close whatever was opened even when opening
        # a later file fails.
        fout = []
        try:
            if FLAGS.write_n_best:
                for k in range(FLAGS.beam_width):
                    fout.append(
                        data_utils.fopen(
                            "%s_%d" % (FLAGS.decode_output, k), 'w'))
            else:
                fout.append(data_utils.fopen(FLAGS.decode_output, 'w'))

            # The label delivered with each batch is not needed for decoding.
            for source_seq, _label in test_set:
                source, source_len = prepare_batch(
                    source_seq,
                    batch_size=config['decode_batch_size'],
                    stride=config['max_seq_length'],
                    maxlen=config['max_seq_length'])
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                predicted_ids = model.predict(sess,
                                              encoder_inputs=source,
                                              encoder_inputs_length=source_len)

                # Write decoding results, highest file index first
                for k, f in reversed(list(enumerate(fout))):
                    f.write(str(source_seq) + '\t\t\t')
                    res = [list(seq[:, k]) for seq in predicted_ids]
                    f.write(str(res) + '\n')
                    if not FLAGS.write_n_best:
                        break

            print('Decoding terminated')
        except IOError as err:
            # Still best-effort, but no longer silent.
            print('Decoding aborted by I/O error: {}'.format(err))
        finally:
            for f in fout:
                f.close()
Пример #6
0
def eval_model(curr_epoch,
               val_dataloader,
               model,
               body_model,
               prediction,
               loss_function,
               logger,
               batch_size,
               device,
               dtype,
               data_name='h36m',
               mpjpe_list=[],
               recone_list=[],
               mpjpe=0,
               recone=0,
               val_iter=1):
    """Evaluate ``model`` on ``val_dataloader`` and log the mean errors.

    Batches whose size differs from ``batch_size`` are skipped. The model is
    switched to eval mode for the pass and always switched back to train
    mode afterwards, even if evaluation raises.

    Args:
        curr_epoch: Step value handed to ``logger``.
        val_dataloader: Iterable of validation batches.
        model: Network called as ``model(images, proj_matricies, batch)``.
        body_model: Callable turning pose/shape parameters into its output.
        prediction: ``pose2rot`` is switched off only when this is 'rotmat'.
        loss_function: Object whose ``eval(pred_output, keypoints_3d_gt)``
            returns ``(mpjpe, recone)`` for one batch.
        logger: Callable ``logger(values, curr_epoch, is_val=True)``.
        batch_size: Expected number of samples per batch.
        device: Device handed to ``prepare_batch``.
        dtype: Accepted for interface compatibility; not used here.
        data_name: Prefix of the logged metric names.
        mpjpe_list: Accepted for interface compatibility; never read or
            mutated here.
        recone_list: Accepted for interface compatibility; never read or
            mutated here.
        mpjpe: Error sum accumulated so far.
        recone: Error sum accumulated so far.
        val_iter: One plus the number of batches already accumulated.

    Returns:
        ``(mean_mpjpe, mean_recone)``. When neither this call nor the
        incoming accumulators cover any batch, both values are NaN.
    """
    print('Evaluate the model ... ')

    pose2rot = prediction != 'rotmat'

    # Start from the mean implied by the incoming accumulators so the result
    # is defined even when every batch of this pass is skipped (this used to
    # end in an UnboundLocalError).
    already_seen = val_iter - 1
    if already_seen > 0:
        mean_mpjpe = mpjpe / already_seen
        mean_recone = recone / already_seen
    else:
        mean_mpjpe = mean_recone = float('nan')

    evaluated_now = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_dataloader:
                images, keypoints_3d_gt, proj_matricies = prepare_batch(
                    batch, device)

                if images.shape[0] != batch_size:
                    # In this case data batch size is not compatible with
                    # SMPL model batch size
                    continue

                pred_pose, pred_betas, pred_global_orient, _ = model(
                    images, proj_matricies, batch)

                pred_output = body_model(betas=pred_betas,
                                         body_pose=pred_pose,
                                         global_orient=pred_global_orient,
                                         pose2rot=pose2rot)

                _mpjpe, _recone = loss_function.eval(pred_output,
                                                     keypoints_3d_gt)
                mpjpe = mpjpe + _mpjpe
                recone = recone + _recone
                mean_mpjpe = mpjpe / val_iter
                mean_recone = recone / val_iter
                val_iter = val_iter + 1
                evaluated_now += 1
    finally:
        # Never leave the caller's model stuck in eval mode.
        model.train()

    if evaluated_now == 0:
        print('Warning: no validation batch of size %d was evaluated' %
              batch_size)

    logger(
        {
            '%s Val MPJPE' % data_name: mean_mpjpe,
            '%s Val RECONE' % data_name: mean_recone
        },
        curr_epoch,
        is_val=True)

    print('Evaluate results: MPJPE: %.3f mm  |  RECONE: %.3f mm' %
          (mean_mpjpe, mean_recone))

    return mean_mpjpe, mean_recone
Пример #7
0
def train():
    """Train the model restored by ``load_model`` on the source training data.

    The configuration comes from ``load_config(FLAGS)``. Every
    ``FLAGS.display_freq`` steps the running perplexity, throughput and the
    current predictions are printed and a summary is written; every
    ``FLAGS.save_freq`` steps a checkpoint and the model configuration (as
    JSON) are saved under ``FLAGS.model_dir``.
    """
    # Load model config
    config = load_config(FLAGS)

    # Load the training data
    train_set = TextIterator(source=config['source_train_data'],
                             batch_size=config['decode_batch_size'],
                             source_dict=config['source_vocabulary'],
                             maxlen=None,
                             n_words_source=config['num_encoder_symbols'])

    # NOTE(review): the validation data loading was commented out, which left
    # ``valid_set`` undefined and crashed the first training step with a
    # NameError. Validation stays disabled until an iterator is assigned
    # here. This assumes ``valid_set`` was never meant to be a module-level
    # global - confirm.
    valid_set = None

    # Initiate TF session
    with tf.Session(config=tf.ConfigProto(
            allow_soft_placement=FLAGS.allow_soft_placement,
            log_device_placement=FLAGS.log_device_placement,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Reload existing checkpoint
        model = load_model(sess, config)

        # Create a log writer object
        log_writer = tf.summary.FileWriter(FLAGS.model_path, graph=sess.graph)

        step_time, loss = 0.0, 0.0
        words_seen, sents_seen = 0, 0
        start_time = time.time()

        for epoch_idx in range(FLAGS.max_epochs):
            if model.global_epoch_step.eval() >= FLAGS.max_epochs:
                print('Training is already complete.', \
                      'current epoch:{}, max epoch:{}'.format(model.global_epoch_step.eval(), FLAGS.max_epochs))
                break
            for idx, source_seq in enumerate(train_set):
                source, source_len = prepare_batch(source_seq)
                # predicted_ids: GreedyDecoder; [batch_size, max_time_step, 1]
                # BeamSearchDecoder; [batch_size, max_time_step, beam_width]
                step_loss, summary, predicted_ids = model.train(
                    sess,
                    encoder_inputs=source,
                    encoder_inputs_length=source_len)
                loss += float(step_loss) / FLAGS.display_freq
                # Only source sequences are fed to the model, so only their
                # lengths count as processed words (the former ``target_len``
                # term referred to a name that is never defined).
                words_seen += float(np.sum(source_len))
                sents_seen += float(source.shape[0])  # batch_size

                display_step = model.global_step.eval()
                if display_step % FLAGS.display_freq == 0:

                    avg_perplexity = math.exp(
                        float(loss)) if loss < 300 else float("inf")

                    time_elapsed = time.time() - start_time
                    step_time = time_elapsed / FLAGS.display_freq

                    words_per_sec = words_seen / time_elapsed
                    sents_per_sec = sents_seen / time_elapsed

                    print('Epoch ', model.global_epoch_step.eval(), 'Step ', display_step, \
                          'Perplexity {0:.2f}'.format(avg_perplexity), 'Step-time ', step_time, \
                          '{0:.2f} sents/s'.format(sents_per_sec), '{0:.2f} words/s'.format(words_per_sec))

                    loss = 0
                    words_seen = 0
                    sents_seen = 0
                    start_time = time.time()

                    # Record training summary for the current batch
                    log_writer.add_summary(summary, display_step)

                    # Write decoding results
                    print(
                        str(source_seq),
                        '\t',
                    )
                    for i in range(len(source_seq)):
                        res = [list(seq[:, i]) for seq in predicted_ids]
                        print(str(res))
                    print('  {}th line decoded'.format(
                        idx * FLAGS.decode_batch_size))

                if valid_set and model.global_step.eval(
                ) % FLAGS.valid_freq == 0:
                    print('Validation step')
                    valid_loss = 0.0
                    valid_sents_seen = 0
                    # Separate names keep the validation pass from clobbering
                    # the training loop's variables.
                    for valid_seq in valid_set:
                        valid_source, valid_source_len = prepare_batch(
                            valid_seq)
                        # NOTE(review): this reuses ``model.train``, which
                        # presumably also updates the weights - switch to an
                        # evaluation-only call if the model offers one.
                        valid_step_loss, _, _ = model.train(
                            sess,
                            encoder_inputs=valid_source,
                            encoder_inputs_length=valid_source_len)
                        valid_batch_size = valid_source.shape[0]
                        valid_loss += valid_step_loss * valid_batch_size
                        valid_sents_seen += valid_batch_size
                    if valid_sents_seen:
                        valid_loss = valid_loss / valid_sents_seen
                        # Same overflow guard as the training perplexity.
                        valid_perplexity = math.exp(
                            valid_loss) if valid_loss < 300 else float("inf")
                        print('Valid perplexity: {0:.2f}'.format(
                            valid_perplexity))

                # Save the model checkpoint
                if model.global_step.eval() % FLAGS.save_freq == 0:
                    print('Saving the model..')
                    checkpoint_path = os.path.join(FLAGS.model_dir,
                                                   FLAGS.model_name)
                    model.save(sess,
                               checkpoint_path,
                               global_step=model.global_step)
                    config_path = '%s-%d.json' % (checkpoint_path,
                                                  model.global_step.eval())
                    # json.dump() writes text, so the file must be opened in
                    # text mode ('wb' raises TypeError on Python 3); the
                    # context manager closes the handle.
                    with open(config_path, 'w') as config_file:
                        json.dump(model.config, config_file, indent=2)
Пример #8
0
def main(cfg, args, val_iter=1):
    """Evaluate a trained model on the configured validation set.

    Builds the network, the body model, the loss function and the validation
    dataloader from ``cfg``, evaluates every full-size batch whose first
    frame index lies in ``[1000, 1500]``, renders a figure per evaluated
    batch and prints the mean errors. For the 'mpi-inf-3dhp' dataset PCK and
    AUC are reported as well.

    Args:
        cfg: Configuration node (``TRAIN.BATCH_SIZE``, ``TRAIN.INIT_WEIGHT``,
            ``DATASET.TYPE`` and ``DEVICE`` are read here).
        args: Command-line arguments; accepted for interface compatibility,
            not used here.
        val_iter: Starting value of the running-mean denominator.
    """
    render = True
    batch_size = cfg.TRAIN.BATCH_SIZE

    # Build neural network model
    model = build_model(cfg)

    # Initialize weight if resume training
    if osp.exists(cfg.TRAIN.INIT_WEIGHT):
        model.load_state_dict(torch.load(cfg.TRAIN.INIT_WEIGHT)['state_dict'])
        print("Evaluating model loaded ...")

    model.eval()

    # Model Summary
    total_params = sum(p.numel() for p in model.parameters())
    print('Total number of parameters is {}'.format(total_params))

    # Build SMPL model
    body_model = build_body_model(cfg, render=render)

    loss_function = build_loss_function(cfg)
    J_regressor = torch.from_numpy(np.load(
        constants.JOINT_REGRESSOR_H36M)).float()

    # Build dataloader
    if cfg.DATASET.TYPE == 'human36m':
        val_dloader = setup_new_human36m_dataloaders(cfg)
    else:
        val_dloader = setup_mpi_dataloaders(cfg)

    print('==> Data loaded...')

    is_mpi = cfg.DATASET.TYPE == 'mpi-inf-3dhp'
    mpjpe, pa_mpjpe, pck, pa_pck, auc, pa_auc = 0, 0, 0, 0, 0, 0
    # Stay None until at least one batch has been evaluated.
    mean_mpjpe = mean_pa_mpjpe = None
    with torch.no_grad():
        with tqdm(total=len(val_dloader)) as prog_bar:
            for batch in val_dloader:
                # NOTE(review): hard-coded frame window - looks like a
                # debugging restriction; confirm before a full evaluation.
                first_frame = batch['frame'][0]
                if first_frame > 1500 or first_frame < 1000:
                    # Skipped batches still count as progress.
                    prog_bar.update(1)
                    continue

                images, keypoints_3d_gt, proj_matricies = prepare_batch(
                    batch, cfg.DEVICE)

                if images.shape[0] != batch_size:
                    # In this case data batch size is not compatible with
                    # SMPL model batch size
                    prog_bar.update(1)
                    continue

                pred_pose, pred_betas, pred_global_orient, _ = model(
                    images, proj_matricies, batch)

                pred_output = body_model(betas=pred_betas,
                                         body_pose=pred_pose,
                                         global_orient=pred_global_orient,
                                         pose2rot=False)

                if is_mpi:
                    # For evaluating the MPI-INF-3DHP dataset, PCK / AUC are
                    # additional error metrics
                    _mpjpe, _pa_mpjpe, _pck, _pa_pck, _auc, _pa_auc = compute_error(
                        pred_output, keypoints_3d_gt, J_regressor)
                    pck += _pck
                    pa_pck += _pa_pck
                    auc += _auc
                    pa_auc += _pa_auc
                else:
                    _mpjpe, _pa_mpjpe = loss_function.eval(
                        pred_output, keypoints_3d_gt)

                mpjpe = mpjpe + _mpjpe
                pa_mpjpe = pa_mpjpe + _pa_mpjpe
                mean_mpjpe = mpjpe / val_iter
                mean_pa_mpjpe = pa_mpjpe / val_iter

                val_iter = val_iter + 1

                msg = 'MPJPE = %.3f,   PA MPJPE = %.3f' % (mean_mpjpe,
                                                           mean_pa_mpjpe)
                prog_bar.set_postfix_str(msg)
                prog_bar.update(1)
                prog_bar.refresh()

                if render:
                    tmp_generate_figure(batch['org_cameras'], pred_output,
                                        body_model, batch['org_images'],
                                        keypoints_3d_gt, batch['action'][0],
                                        batch['frame'][0])

    if mean_mpjpe is None:
        # Every batch was filtered out; there is nothing to report (this
        # used to end in an UnboundLocalError).
        print('No validation batch was evaluated; no results to report.')
        return

    print('Evaluate results: MPJPE: %.3f mm  |  RECONE: %.3f mm' %
          (mean_mpjpe, mean_pa_mpjpe))
    if is_mpi:
        # ``val_iter`` was incremented after the last batch; use the same
        # denominator as the MPJPE means above.
        denom = val_iter - 1
        print('PCK : %.1f | AUC : %.1f |   PCK (PA) : %.1f | AUC (PA) : %.1f' %
              (pck / denom * 100, auc / denom * 100,
               pa_pck / denom * 100, pa_auc / denom * 100))
Пример #9
0
def _validate_detector(sess, model, valid_set, config):
    """Return sample-weighted ``(accuracy, loss)`` of ``model`` on ``valid_set``.

    Returns ``None`` when the validation iterator yields no samples.
    """
    total_acc = 0
    total_loss = 0
    total_num = 0
    for valid_batch in valid_set:
        sub_sources, sub_labels = prepare_batch(
            valid_batch[0],
            valid_batch[1],
            max_batch=config['max_batch'],
            maxlen=config['maxlen'],
            stride=config['stride'],
            batch_size=config['batch_size'])
        for sub_source, sub_label in zip(sub_sources, sub_labels):
            pred, _, acc, loss = model.predict(sess, sub_source, sub_label)
            num_samples = pred.shape[0]
            total_acc += acc * num_samples
            total_loss += loss * num_samples
            total_num += num_samples
    if total_num == 0:
        return None
    return total_acc / total_num, total_loss / total_num


def train(config, maxs):
    """Pretrain a ``Detector`` and save it whenever validation accuracy improves.

    After every training step whose global step is a multiple of 50, the
    model is evaluated on ``config['valid']``; it is saved with ``save``
    when the mean accuracy exceeds the best value seen so far.

    Args:
        config: Mapping providing 'input', 'valid', 'batch_size',
            'source_vocabulary', 'max_epochs', 'max_batch', 'maxlen' and
            'stride'.
        maxs: Best validation accuracy so far; the model is saved only when
            this is exceeded.
    """
    tf.reset_default_graph()
    from data.data_iterator import TextIterator, Butian_TextIterator
    train_set = TextIterator(source=config['input'],
                             batch_size=config['batch_size'],
                             source_dict=config['source_vocabulary'],
                             shuffle_each_epoch=True)
    valid_set = TextIterator(source=config['valid'],
                             batch_size=config['batch_size'],
                             source_dict=config['source_vocabulary'],
                             shuffle_each_epoch=False)
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False,
                                          gpu_options=tf.GPUOptions(
                                              allow_growth=True))) as sess:
        model = Detector(config, 'pretrain')
        # NOTE(review): the starting checkpoint is a hard-coded absolute
        # path; consider moving it into ``config``.
        model.restore_specific(
            sess,
            '/home/dbtest/lan/Malware/model/detector/gru_False_2att/detector-9650'
        )
        for epoch_idx in range(config['max_epochs']):
            for train_sources in train_set:
                sources, labels = prepare_batch(
                    train_sources[0],
                    train_sources[1],
                    max_batch=config['max_batch'],
                    maxlen=config['maxlen'],
                    stride=config['stride'],
                    batch_size=config['batch_size'])
                for source, label in zip(sources, labels):
                    # The values returned by the training step are not used.
                    model.pretrain(sess, source, label)
                    if model.global_step.eval() % 50 != 0:
                        continue

                    metrics = _validate_detector(sess, model, valid_set,
                                                 config)
                    if metrics is None:
                        # Previously a ZeroDivisionError.
                        print('validation set produced no samples; '
                              'skipping evaluation')
                        continue

                    valid_acc, valid_loss = metrics
                    print("acc {:g}, softmax_loss {:g}".format(
                        valid_acc, valid_loss))
                    if valid_acc > maxs:
                        save(sess, config, model)
                        maxs = valid_acc