Example #1
0
def predict(model, args):
    """Interactively score prediction sequences read as JSON lines from stdin.

    Each input line is expected to be a JSON object with keys ``ref``
    (history records used to build the student state) and ``pred``
    (records to score).  Prints the list of predicted scores for the
    ``pred`` records, or ``[]`` on malformed input.  Loops until EOF.
    """
    try:
        # torch >= 0.4: disable autograd globally; older versions rely on
        # the volatile=True flags on the Variables below.
        torch.set_grad_enabled(False)
    except AttributeError:
        pass
    logging.info('model: %s, setup: %s' %
                 (type(model).__name__, str(model.args)))
    logging.info('loading dataset')

    if args.snapshot is None:
        epoch = load_last_snapshot(model, args.workspace)
    else:
        epoch = args.snapshot
        load_snapshot(model, args.workspace, epoch)
    logging.info('loaded model at epoch %s', str(epoch))

    to_categorical = Categorical('</s>')
    to_categorical.load_dict(model.words)
    trans = to_categorical(Words(':', null='</s>'))

    while True:
        # loop over inputs
        try:
            line = input()
        except EOFError:
            logging.info('bye')
            break

        try:
            # json.loads() lost its `encoding` parameter in Python 3.9 and
            # passing it raised an uncaught TypeError; input() already
            # yields str, so no decoding argument is needed.
            obj = json.loads(line)
            ref_seq = obj['ref']
            pred_seq = obj['pred']
        except (json.decoder.JSONDecodeError, KeyError):
            print('[]')
            continue

        # Feed the reference history through the model to build state h.
        h = None
        for i, item in enumerate(ref_seq):
            x = trans.apply(None, item['fea'])
            x = Variable(torch.LongTensor(x), volatile=True)
            # NOTE(review): here 't' feeds the score input and 's' feeds
            # the time input, the opposite of the loop below — verify the
            # key semantics against the producer of these records.
            score = Variable(torch.FloatTensor([item['t']]), volatile=True)
            t = Variable(torch.FloatTensor([item['s']]), volatile=True)
            _, h = model(x, score, t, h)

        pred_scores = []

        # Score each prediction record against the frozen state h
        # (score input is zeroed since the outcome is unknown).
        for i, item in enumerate(pred_seq):
            x = trans.apply(None, item['fea'])
            x = Variable(torch.LongTensor(x), volatile=True)
            score = Variable(torch.FloatTensor([0.]), volatile=True)
            t = Variable(torch.FloatTensor([item['t']]), volatile=True)
            s, _ = model(x, score, t, h)
            pred_scores.append(s.cpu().data[0][0])

        print(pred_scores)
Example #2
0
    def infer_stack(args):
        """Run the network over every frame of a multi-page TIFF stack.

        Reads the stack at ``args.stack``, replicates each grayscale frame
        across three channels, runs the network loaded from ``args.network``
        frame by frame, and writes a multi-page 16-bit TIFF to ``args.out``.
        """
        stack = Image.open(args.stack)
        height, width = np.shape(stack)
        n_frames = stack.n_frames

        # Replicate each grayscale frame into all three channels.
        frames = np.zeros((n_frames, 3, height, width))
        for idx in range(n_frames):
            stack.seek(idx)
            frame = np.array(stack)
            for chan in range(3):
                frames[idx, chan, :, :] = frame
        # Map pixel values into [-0.5, 0.5] as float32.
        frames = frames.astype("float32")
        frames = frames / 255.0 - 0.5

        tfutil.init_tf(tf_config)
        net = util.load_snapshot(args.network)

        # Infer one frame at a time into a 16-bit result stack.
        result = np.empty((n_frames, height, width), dtype="uint16")
        for idx in range(n_frames):
            result[idx, :, :] = util.infer_image_pp(net, frames[idx, :, :, :])

        # Save frame 0 with the remaining frames appended as TIFF pages.
        rest = [Image.fromarray(result[idx, :, :])
                for idx in range(1, result.shape[0])]
        Image.fromarray(result[0, :, :]).save(args.out, format="tiff",
                                              append_images=rest,
                                              save_all=True)
def infer_image(network_snapshot: str, image: str, out_image: str):
    """Run the network on a single RGB image file and save the prediction.

    Loads the network from *network_snapshot*, feeds it the image at
    *image* (converted to RGB, channels-first, values mapped into
    [-0.5, 0.5]), and writes the predicted image to *out_image*.
    """
    tfutil.init_tf(config.tf_config)
    net = util.load_snapshot(network_snapshot)
    rgb = PIL.Image.open(image).convert('RGB')
    pixels = np.array(rgb, dtype=np.float32)
    # [H, W, RGB] -> [RGB, H, W], scaled into [-0.5, 0.5]
    chw = pixels.transpose([2, 0, 1]) / 255.0 - 0.5
    pred255 = util.infer_image(net, chw)
    hwc = pred255.transpose([1, 2, 0])  # [RGB, H, W] -> [H, W, RGB]
    PIL.Image.fromarray(hwc, 'RGB').save(os.path.join(out_image))
    print('Inferred image saved in', out_image)
Example #4
0
def validate(submit_config: dnnlib.SubmitConfig, noise: dict, dataset: dict, network_snapshot: str):
    """Score a stored network snapshot against the configured validation set."""
    augmenter = dnnlib.util.call_func_by_name(**noise)
    val_set = ValidationSet(submit_config)
    val_set.load(**dataset)

    run_ctx = dnnlib.RunContext(submit_config, config)

    tfutil.init_tf(config.tf_config)

    # Single evaluation pass (step 0) pinned to the first GPU.
    with tf.device("/gpu:0"):
        network = util.load_snapshot(network_snapshot)
        val_set.evaluate(network, 0, augmenter.add_validation_noise_np)
    run_ctx.close()
Example #5
0
def validate(submit_config: submit.SubmitConfig, tf_config: dict, noise: dict,
             dataset: dict, network_snapshot: str):
    """Score a stored network snapshot against the configured validation set."""
    # NOTE(review): `noise` is annotated as dict but accessed via
    # attributes — presumably an EasyDict-style config; confirm at callers.
    augmenter = noise.func(**noise.func_kwargs)
    val_set = ValidationSet(submit_config)
    val_set.load(**dataset)

    run_ctx = RunContext(submit_config)

    tfutil.init_tf(tf_config)

    # Single evaluation pass (step 0) pinned to the first GPU.
    with tf.device("/gpu:0"):
        network = util.load_snapshot(network_snapshot)
        val_set.evaluate(network, 0,
                         augmenter.add_validation_noise_np)
    run_ctx.close()
Example #6
0
# CLI arguments: trained-network directory and training TFRecord path.
parser.add_argument("network_dir", help="path to network directory")
parser.add_argument("tf_train", help="path to tf train datasets")

#network_dir = "/home/pierre/cam/denoising/noise2noise/results/00024-autoencoder"
#tf_train = "/mnt/data/denoising_data/tf_pa_train_prepost.tf"

args = parser.parse_args()

# Install a default TF session if none exists.  This pokes the private
# `_default_session` attribute to keep the context manager alive without
# nesting checks — fragile across TF versions.  NOTE(review): consider
# holding the `session.as_default()` context manager in a named variable
# instead of stashing it on the session object.
if tf.get_default_session() is None:
    session = tf.Session(config=tf.ConfigProto())
    session._default_session = session.as_default()
    session._default_session.enforce_nesting = False
    session._default_session.__enter__() # pylint: disable=no-member

net = util.load_snapshot(args.network_dir + "/network_169000.pickle")

# Queue-based reader — created but unused below (tf.data is used instead).
reader = tf.TFRecordReader()

# Schema of one serialized example: image shape plus two byte-encoded
# image payloads.
feats = {'shape': tf.FixedLenFeature([3], tf.int64),
         'data1': tf.FixedLenFeature([], tf.string),
         'data2': tf.FixedLenFeature([], tf.string)}

def _parse_image_function(example_proto):
  # Deserialize a single tf.Example according to `feats`.
  return tf.parse_single_example(example_proto, feats)

raw_image_dataset = tf.data.TFRecordDataset(args.tf_train)
dataset = raw_image_dataset.map(_parse_image_function)
dat = dataset.make_one_shot_iterator().get_next()
print(dat)
# Deliberate debug stop: halt right after printing the parsed tensors.
assert(False)
Example #7
0
def testseq(model, args):
    """Evaluate `model` over whole user sequences from the test split(s).

    For every record the model is also probed with the sequence's *last*
    record (``s_last``), so the dumped predictions trace how the forecast
    for the final exercise evolves as the student state is updated.
    Aggregate RMSE / MAE / accuracy are logged every ``args.print_every``
    sequences; per-sequence predictions are written via ``print_seq``.
    """
    try:
        # torch >= 0.4 path; older torch relies on the volatile=True
        # flags on the Variables created below.
        torch.set_grad_enabled(False)
    except AttributeError:
        pass
    logging.info('model: %s, setup: %s' %
                 (type(model).__name__, str(model.args)))
    logging.info('loading dataset')

    data = get_dataset(args.dataset)
    data.random_level = args.random_level

    # Build `testsets` as (name, data, reference data) triples according
    # to the chosen split strategy.  The reference data supplies each
    # user's warm-up history (may be empty).
    if not args.dataset.endswith('test'):
        if args.split_method == 'user':
            _, data = data.split_user(args.frac)
            testsets = [('user_split', data, {})]
        elif args.split_method == 'future':
            _, data = data.split_future(args.frac)
            testsets = [('future_split', data, {})]
        elif args.split_method == 'old':
            trainset, _, _, _ = data.split()
            data = trainset.get_seq()
            train, user, exam, new = data.split()
            train = train.get_seq()
            user = user.get_seq()
            exam = exam.get_seq()
            new = new.get_seq()
            testsets = zip(['user', 'exam', 'new'], [user, exam, new],
                           [{}, train, user])
        else:
            if args.ref_set:
                ref = get_dataset(args.ref_set)
                ref.random_level = args.random_level
                testsets = [(args.dataset.split('/')[-1], data.get_seq(),
                             ref.get_seq())]
            else:
                testsets = [('student', data.get_seq(), {})]
    else:
        testsets = [('school', data.get_seq(), {})]

    if args.input_knowledge:
        # Map each exercise uuid to a multi-hot knowledge-concept vector.
        logging.info('loading knowledge concepts')
        topic_dic = {}
        kcat = Categorical(one_hot=True)
        kcat.load_dict(open(model.args['knows']).read().split('\n'))
        know = 'data/id_firstknow.txt' if 'first' in model.args['knows'] \
            else 'data/id_know.txt'
        for line in open(know):
            uuid, know = line.strip().split(' ')
            know = know.split(',')
            # max over the one-hot rows -> multi-hot vector for the topic
            topic_dic[uuid] = torch.LongTensor(kcat.apply(None,
                                                          know)).max(0)[0]
        zero = [0] * len(kcat.apply(None, '<NULL>'))

    if args.input_text:
        logging.info('loading exercise texts')
        topics = get_topics(args.dataset, model.words)

    if args.snapshot is None:
        epoch = load_last_snapshot(model, args.workspace)
    else:
        epoch = args.snapshot
        load_snapshot(model, args.workspace, epoch)
    logging.info('loaded model at epoch %s', str(epoch))

    if use_cuda:
        model.cuda()

    for testset, data, ref_data in testsets:
        logging.info('testing on: %s', testset)
        f = open_result(args.workspace, testset, epoch)

        then = time.time()

        # Running totals across all evaluated sequences.
        total_mse = 0
        total_mae = 0
        total_acc = 0
        total_seq_cnt = 0

        users = list(data)
        random.shuffle(users)
        seq_cnt = len(users)

        MSE = torch.nn.MSELoss()
        MAE = torch.nn.L1Loss()

        # Evaluation is capped at 5000 randomly chosen sequences.
        for user in users[:5000]:
            total_seq_cnt += 1

            seq = data[user]
            if user in ref_data:
                ref_seq = ref_data[user]
            else:
                ref_seq = []

            length = len(seq)
            ref_len = len(ref_seq)
            seq = ref_seq + seq

            # If the warm-up history is shorter than requested, borrow
            # the difference from the scored portion of the sequence.
            if ref_len < args.ref_len:
                length = length + ref_len - args.ref_len
                ref_len = args.ref_len

            # Always keep at least one scored record.
            if length < 1:
                ref_len = ref_len + length - 1
                length = 1

            mse = 0
            mae = 0
            acc = 0

            # seq2 = []
            # seen = set()
            # for item in seq:
            #     if item.topic in seen:
            #         continue
            #     seen.add(item.topic)
            #     seq2.append(item)

            # seq = seq2
            # length = len(seq) - ref_len

            pred_scores = Variable(torch.zeros(len(seq)))

            s = None
            h = None

            for i, item in enumerate(seq):
                # get last record for testing and current record for updating
                if args.input_knowledge:
                    if item.topic in topic_dic:
                        knowledge = topic_dic[item.topic]
                        knowledge_last = topic_dic[seq[-1].topic]
                    else:
                        # unseen topic: fall back to the all-zero vector
                        knowledge = zero
                        knowledge_last = zero
                    knowledge = Variable(torch.LongTensor(knowledge))
                    knowledge_last = Variable(torch.LongTensor(knowledge_last),
                                              volatile=True)

                if args.input_text:
                    text = topics.get(item.topic).content
                    text = Variable(torch.LongTensor(text))
                    text_last = topics.get(seq[-1].topic).content
                    text_last = Variable(torch.LongTensor(text_last),
                                         volatile=True)

                score = Variable(torch.FloatTensor([item.score]),
                                 volatile=True)
                score_last = Variable(torch.FloatTensor([round(seq[-1].score)
                                                         ]),
                                      volatile=True)
                item_time = Variable(torch.FloatTensor([item.time]),
                                     volatile=True)
                time_last = Variable(torch.FloatTensor([seq[-1].time]),
                                     volatile=True)

                # test last score of each seq for seq figure
                # (the class-name prefix selects the input signature:
                # DK* takes knowledge, RA* takes text, EK* takes both)
                if type(model).__name__.startswith('DK'):
                    s, _ = model(knowledge_last, score_last, time_last, h)
                elif type(model).__name__.startswith('RA'):
                    s, _ = model(text_last, score_last, time_last, h)
                elif type(model).__name__.startswith('EK'):
                    s, _ = model(text_last, knowledge_last, score_last,
                                 time_last, h)
                s_last = torch.clamp(s, 0, 1)

                # update student state h until the fit process reaches trainset
                if ref_len > 0 and i > ref_len:
                    # past the warm-up: predict without committing to h
                    if type(model).__name__.startswith('DK'):
                        s, _ = model(knowledge, score, item_time, h)
                    elif type(model).__name__.startswith('RA'):
                        s, _ = model(text, score, item_time, h)
                    elif type(model).__name__.startswith('EK'):
                        s, _ = model(text, knowledge, score, item_time, h)
                else:
                    # within the warm-up (or no warm-up): advance state h
                    if type(model).__name__.startswith('DK'):
                        s, h = model(knowledge, score, item_time, h)
                    elif type(model).__name__.startswith('RA'):
                        s, h = model(text, score, item_time, h)
                    elif type(model).__name__.startswith('EK'):
                        s, h = model(text, knowledge, score, item_time, h)

                pred_scores[i] = s_last

                if args.loss == 'cross_entropy':
                    s = F.sigmoid(s)
                else:
                    s = torch.clamp(s, 0, 1)
                # warm-up records contribute no metrics
                if i < ref_len:
                    continue

                mse += MSE(s, score)
                m = MAE(s, score).data[0]
                mae += m
                acc += m < 0.5  # counted correct when within 0.5

            print_seq(seq, pred_scores.data.cpu().numpy(), ref_len, f, True)

            mse /= length
            mae /= length
            acc = float(acc) / length

            total_mse += mse.data[0]
            total_mae += mae
            total_acc += acc

            # Log aggregates every print_every sequences and at the end.
            if total_seq_cnt % args.print_every != 0 and total_seq_cnt != seq_cnt:
                continue

            now = time.time()
            duration = (now - then) / 60

            logging.info(
                '[%d/%d] (%.2f seqs/min) '
                'rmse %.6f, mae %.6f, acc %.6f' %
                (total_seq_cnt, seq_cnt,
                 ((total_seq_cnt - 1) % args.print_every + 1) / duration,
                 math.sqrt(total_mse / total_seq_cnt),
                 total_mae / total_seq_cnt, total_acc / total_seq_cnt))
            then = now
        f.close()
Example #8
0
def test(model, args):
    """Evaluate `model` record-by-record on the configured test split(s).

    Unlike ``testseq`` this deduplicates topics within each sequence and
    supports several probing modes (``args.test_on_last``,
    ``args.test_as_seq``, ``args.test_on_one``).  Aggregate
    RMSE / MAE / accuracy are logged every ``args.print_every`` sequences
    and per-sequence predictions are dumped via ``print_seq``.
    """
    try:
        # torch >= 0.4 path; older torch relies on the volatile=True
        # flags on the Variables created below.
        torch.set_grad_enabled(False)
    except AttributeError:
        pass
    logging.info('model: %s, setup: %s' %
                 (type(model).__name__, str(model.args)))
    logging.info('loading dataset')
    data = get_dataset(args.dataset)
    data.random_level = args.random_level

    # Build `testsets` as (name, data, reference data) triples according
    # to the chosen split strategy.  The reference data supplies each
    # user's warm-up history (may be empty).
    if not args.dataset.endswith('test'):
        if args.split_method == 'user':
            _, data = data.split_user(args.frac)
            testsets = [('user_split', data, {})]
        elif args.split_method == 'future':
            _, data = data.split_future(args.frac)
            testsets = [('future_split', data, {})]
        elif args.split_method == 'old':
            trainset, _, _, _ = data.split()
            data = trainset.get_seq()
            train, user, exam, new = data.split()
            train = train.get_seq()
            user = user.get_seq()
            exam = exam.get_seq()
            new = new.get_seq()
            testsets = zip(['user', 'exam', 'new'], [user, exam, new],
                           [{}, train, user])
        else:
            if args.ref_set:
                ref = get_dataset(args.ref_set)
                ref.random_level = args.random_level
                testsets = [(args.dataset.split('/')[-1], data.get_seq(),
                             ref.get_seq())]
            else:
                testsets = [('student', data.get_seq(), {})]
    else:
        testsets = [('school', data.get_seq(), {})]

    # DK* models consume knowledge-concept vectors; other models consume
    # exercise text.
    if type(model).__name__.startswith('DK'):
        topic_dic = {}
        kcat = Categorical(one_hot=True)
        kcat.load_dict(open('data/know_list.txt').read().split('\n'))
        for line in open('data/id_know.txt'):
            uuid, know = line.strip().split(' ')
            know = know.split(',')
            # max over one-hot rows -> multi-hot vector for the topic
            topic_dic[uuid] = \
                torch.LongTensor(kcat.apply(None, know)) \
                .max(0)[0] \
                .type(torch.LongTensor)
        zero = [0] * len(kcat.apply(None, '<NULL>'))
    else:
        topics = get_topics(args.dataset, model.words)

    if args.snapshot is None:
        epoch = load_last_snapshot(model, args.workspace)
    else:
        epoch = args.snapshot
        load_snapshot(model, args.workspace, epoch)
    logging.info('loaded model at epoch %s', str(epoch))

    if use_cuda:
        model.cuda()

    for testset, data, ref_data in testsets:
        logging.info('testing on: %s', testset)
        f = open_result(args.workspace, testset, epoch)

        then = time.time()

        # Running totals across all evaluated sequences.
        total_mse = 0
        total_mae = 0
        total_acc = 0
        total_seq_cnt = 0

        users = list(data)
        random.shuffle(users)
        seq_cnt = len(users)

        MSE = torch.nn.MSELoss()
        MAE = torch.nn.L1Loss()

        # Evaluation is capped at 5000 randomly chosen sequences.
        for user in users[:5000]:
            seq = data[user]
            if user in ref_data:
                ref_seq = ref_data[user]
            else:
                ref_seq = []

            # Deduplicate topics in the warm-up history...
            seq2 = []
            seen = set()
            for item in ref_seq:
                if item.topic in seen:
                    continue
                seen.add(item.topic)
                seq2.append(item)
            ref_seq = seq2

            # ...and drop scored records whose topic already appeared.
            seq2 = []
            for item in seq:
                if item.topic in seen:
                    continue
                seen.add(item.topic)
                seq2.append(item)
            seq = seq2

            ref_len = len(ref_seq)
            seq = ref_seq + seq
            length = len(seq)

            # If the warm-up history is shorter than requested, borrow
            # the difference from the scored portion of the sequence.
            if ref_len < args.ref_len:
                length = length + ref_len - args.ref_len
                ref_len = args.ref_len

            # Skip sequences with nothing left to score.
            if length < 1:
                continue
            total_seq_cnt += 1

            mse = 0
            mae = 0
            acc = 0

            pred_scores = Variable(torch.zeros(len(seq)))

            s = None
            h = None

            for i, item in enumerate(seq):
                if args.test_on_last:
                    # Probe the model with the sequence's final record
                    # (state h is not updated by this call).
                    x = topics.get(seq[-1].topic).content
                    x = Variable(torch.LongTensor(x), volatile=True)
                    score = Variable(torch.FloatTensor([round(seq[-1].score)]),
                                     volatile=True)
                    t = Variable(torch.FloatTensor([seq[-1].time]),
                                 volatile=True)
                    s, _ = model(x, score, t, h)
                    s_last = torch.clamp(s, 0, 1)
                # Pick the model input for the current record.
                if type(model).__name__.startswith('DK'):
                    if item.topic in topic_dic:
                        x = topic_dic[item.topic]
                    else:
                        x = zero
                else:
                    x = topics.get(item.topic).content
                x = Variable(torch.LongTensor(x))
                score = Variable(torch.FloatTensor([round(item.score)]),
                                 volatile=True)
                t = Variable(torch.FloatTensor([item.time]), volatile=True)
                if args.test_as_seq and i > ref_len and ref_len > 0:
                    # Feed the model's own previous prediction as the score.
                    s, h = model(x, s.view(1), t, h)
                else:
                    if ref_len > 0 and i > ref_len and not args.test_on_one:
                        # Past the warm-up: predict without committing to h.
                        s, _ = model(x, score, t, h)
                    else:
                        s, h = model(x, score, t, h)
                if args.loss == 'cross_entropy':
                    s = F.sigmoid(s)
                else:
                    s = torch.clamp(s, 0, 1)
                if args.test_on_last:
                    pred_scores[i] = s_last
                else:
                    pred_scores[i] = s
                # Warm-up records contribute no metrics.
                if i < ref_len:
                    continue
                mse += MSE(s, score)
                m = MAE(s, score).data[0]
                mae += m
                acc += m < 0.5  # counted correct when within 0.5

            print_seq(seq,
                      pred_scores.data.cpu().numpy(), ref_len, f,
                      args.test_on_last)

            mse /= length
            mae /= length
            acc /= length

            total_mse += mse.data[0]
            total_mae += mae
            total_acc += acc

            # Log aggregates every print_every sequences and at the end.
            if total_seq_cnt % args.print_every != 0 and \
                    total_seq_cnt != seq_cnt:
                continue

            now = time.time()
            duration = (now - then) / 60

            logging.info(
                '[%d/%d] (%.2f seqs/min) '
                'rmse %.6f, mae %.6f, acc %.6f' %
                (total_seq_cnt, seq_cnt,
                 ((total_seq_cnt - 1) % args.print_every + 1) / duration,
                 math.sqrt(total_mse / total_seq_cnt),
                 total_mae / total_seq_cnt, total_acc / total_seq_cnt))
            then = now

        f.close()