Example #1
def main(args):
    model = Agent('agent', 3, 1)
    ph.initialize_global_variables()
    model.init()

    render = False
    env = gym.make('Pendulum-v0')
    env_noise = deep_rl.NormalNoise(0.5)
    for i in range(args.num_loops):
        total_r = 0
        s = env.reset()
        for t in range(1000):
            if render:
                env.render()

            a, = model.predict([s])[0]
            a = env_noise.add_noise(a)
            a *= env.action_space.high
            env_noise.discount(0.9999)

            s_, r, done, info = env.step(a)

            model.feedback(s, a, r, s_)
            model.train(args.batch_size)

            total_r += r
            s = s_
            if done:
                print('[%d] %f' % (i, total_r))
                if total_r > -300 and i >= 100:
                    render = True
                break
    return 0
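
The deep_rl.NormalNoise helper used above is not shown in this example; a minimal sketch of such a decaying Gaussian exploration noise, assuming add_noise() perturbs the action with the current scale and discount() shrinks that scale, might look like this:

import numpy as np

class NormalNoise:
    """Hypothetical stand-in for deep_rl.NormalNoise: zero-mean Gaussian
    exploration noise whose scale decays via discount()."""

    def __init__(self, sigma):
        self._sigma = sigma

    def add_noise(self, action):
        # Perturb the action with Gaussian noise at the current scale.
        return action + np.random.normal(0.0, self._sigma, size=np.shape(action))

    def discount(self, rate):
        # Shrink the noise scale, e.g. by 0.9999 per step as in the loop above.
        self._sigma *= rate
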
Example #2
def main1(args):
    log_file = 'logs_' + args.name
    model_dir = 'models_' + args.name
    exp = common.Experiment(log_file, model_dir)
    we = ph.utils.WordEmbedding()  # modified the original function; after initialization, embeddings can be queried directly
    #
    trainer = Main('amazon_rnn', 300)  # 300-dimensional word vectors
    ph.initialize_global_variables()
    #
    test_list = []

    for i in range(domain_num):
        print('***********************')
        print('domain:', i)
        vocab, train_data, dev_data, test_data = build_dataset_LL(i)  # rewritten data loading
        train_ds = common.TrainSource(train_data, we, i)  # implements the original DataSource interface
        dev_ds = common.TrainSource(dev_data, we, i)
        test_ds = common.TrainSource(test_data, we, i)
        test_list.append(test_ds)
        #
        exp.load_model(trainer)
        seq_stat = trainer.stat.read_stat(trainer.flat_seq)
        states_stat = trainer.stat.read_stat(trainer.flat_states)
        trainer._optimizer.update_mask(trainer.shared.cell.wz, seq_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.wr, seq_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.wh, seq_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.uz, states_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.ur, states_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.uh, states_stat, i)
        trainer.add_data_trainer(train_ds, 32)  # batch size 32
        # trainer.add_screen_logger('train', ('Loss', 'Norm'), interval=1)
        trainer.add_data_validator(test_ds, 32, interval=20)
        # trainer.add_screen_logger(
        #     "validate",
        #     ('hit_pos', 'hit_neg', 'pred_pos', 'pred_neg', 'Error'),
        #     message='[%d]' % i,
        #     interval=20
        # )
        trainer.add_fitter(common.DevFitter(dev_ds, 32, 20))
        trainer.fit(args.num_loops)
        trainer.clear_fitters()
        #
        exp.dump_model(trainer)

        # test turn for bwt
        for test_data in test_list:
            trainer.add_data_validator(test_data, 32, interval=1)
            trainer.add_screen_logger(
                "validate",
                ('hit_pos', 'hit_neg', 'pred_pos', 'pred_neg', 'Error'),
                message='[%d]' % i,
                interval=1)
            trainer.fit(1)
            trainer.clear_fitters()

        trainer.stat.update_stats()
    return 0
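
main1() expects an args object with name and num_loops attributes; a hypothetical command-line driver (the flag names are assumptions inferred from those attributes) could be:

import argparse
import sys

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--name', required=True,
                        help='run name used for the log and model directories')
    parser.add_argument('--num-loops', dest='num_loops', type=int, default=100,
                        help='number of training loops per domain')
    sys.exit(main1(parser.parse_args()))
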
Example #3
    def _main(self, args):
        ph.set_tf_log_level(ph.TF_LOG_NO_WARN)
        with pymongo.MongoClient('sis2.ustcdm.org') as conn:
            conn['admin'].authenticate('root', 'SELECT * FROM users;')

            coll_train = conn['imagenet_vgg']['train']
            # coll_valid = conn['imagenet_vgg']['valid']

            mapping = utils.IndexMapping(conn['imagenet_deepme']['clusters'])

            ds_train = DataSource(coll_train, mapping, True, args.batch_size)
            # ds_valid = DataSource(coll_valid, mapping, False, args.batch_size)

            model = Model('model', args.keep_prob)
            ph.initialize_global_variables()
            ph.io.load_model_from_file(model['encoder'], args.vgg16, 'vgg16')

            #
            # train the last layer
            progress = tqdm(total=args.num_train, ncols=96, desc='Training')
            for i in range(args.num_train):
                self.checkpoint()
                try:
                    _, image, target = ds_train.next()
                except StopIteration:
                    _, image, target = ds_train.next()
                loss, = model.train(image, target)
                progress.set_description(f'Training loss={loss:.06f}',
                                         refresh=False)
                progress.update()
            progress.close()

            #
            # fine tuning all the parameters
            progress = tqdm(total=args.num_loops, ncols=96, desc='Fine tuning')
            for i in range(args.num_loops):
                self.checkpoint()
                try:
                    _, image, target = ds_train.next()
                except StopIteration:
                    _, image, target = ds_train.next()
                loss, = model.fine_tune(image, target)
                progress.set_description(f'Fine tuning loss={loss:.06f}',
                                         refresh=False)
                progress.update()
            progress.close()
            ds_train = None
            ds_valid = None

            if args.write_results:
                coll = conn['imagenet_vgg']['test']
                coll_output = conn['imagenet_deepme']['gate_test']
                self._write_result(model, coll, mapping, coll_output)

        print('All clear.')
        return 0
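
The DataSource used here is defined elsewhere; a minimal sketch of a MongoDB-backed source with the same next() contract (a finished epoch is signalled by StopIteration, which is why the loops above simply retry the call) could look like the following. The 'image' and 'target' field names are assumptions.

class MongoDataSource:
    """Hypothetical sketch of a cycling data source over a MongoDB collection."""

    def __init__(self, coll, batch_size):
        self._coll = coll
        self._batch_size = batch_size
        self._cursor = iter(())  # empty; re-opened on the first next()

    def next(self):
        batch = []
        for _ in range(self._batch_size):
            try:
                batch.append(next(self._cursor))
            except StopIteration:
                if batch:
                    break
                # Epoch boundary: re-open the cursor and let the caller retry.
                self._cursor = self._coll.find()
                raise
        ids = [doc['_id'] for doc in batch]
        images = [doc['image'] for doc in batch]    # field name is an assumption
        targets = [doc['target'] for doc in batch]  # field name is an assumption
        return ids, images, targets
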
Example #4
def main1(args):
    log_file = 'logs_' + args.name
    model_dir = 'models_' + args.name
    exp = common.Experiment(log_file, model_dir)
    with pymongo.MongoClient('localhost:27017') as client:
        client['admin'].authenticate('root', 'SELECT * FROM password;')
        we = ph.utils.WordEmbedding(client['word2vec']['glove_twitter'])
        #
        trainer = Main('amazon_rnn', 200)
        ph.initialize_global_variables()
        #
        for i, domain in enumerate(DOMAINS):
            print('domain:', domain)
            train_ds = common.TrainSource(client['ayasa']['train1'], we,
                                          domain)
            dev_ds = common.DevSource(client['ayasa']['train1'], we, domain)
            test_ds = common.TestSource(client['ayasa']['test1'], we, domain)
            #
            exp.load_model(trainer)
            seq_stat = trainer.stat.read_stat(trainer.flat_seq)
            states_stat = trainer.stat.read_stat(trainer.flat_states)
            trainer._optimizer.update_mask(trainer.shared.cell.wz, seq_stat, i)
            trainer._optimizer.update_mask(trainer.shared.cell.wr, seq_stat, i)
            trainer._optimizer.update_mask(trainer.shared.cell.wh, seq_stat, i)
            trainer._optimizer.update_mask(trainer.shared.cell.uz, states_stat,
                                           i)
            trainer._optimizer.update_mask(trainer.shared.cell.ur, states_stat,
                                           i)
            trainer._optimizer.update_mask(trainer.shared.cell.uh, states_stat,
                                           i)
            trainer.add_data_trainer(train_ds, args.batch_size)
            trainer.add_screen_logger(ph.CONTEXT_TRAIN, ('Loss', 'Norm'),
                                      interval=1)
            trainer.add_data_validator(test_ds, 32, interval=20)
            trainer.add_screen_logger(
                ph.CONTEXT_VALID,
                ('hit_pos', 'hit_neg', 'pred_pos', 'pred_neg', 'Error'),
                message='[%d]' % domain,
                interval=20)
            trainer.add_fitter(common.DevFitter(dev_ds, args.batch_size, 20))
            trainer.fit(args.num_loops)
            trainer.clear_fitters()
            #
            # exp.dump_model(trainer, 'domain' + str(domain))
            trainer.add_data_validator(test_ds, 32, interval=1)
            trainer.add_screen_logger(
                ph.CONTEXT_VALID,
                ('hit_pos', 'hit_neg', 'pred_pos', 'pred_neg', 'Error'),
                message='[%d]' % domain,
                interval=1)
            trainer.fit(1)
            trainer.clear_fitters()
            trainer.stat.update_stats()
    return 0
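
ph.utils.WordEmbedding here reads vectors from the word2vec.glove_twitter collection; a rough sketch of such a lookup with an in-memory cache is below. The document schema (a 'word' key and a list-valued 'vec' key) and the zero-vector fallback are assumptions, not the library's actual behaviour.

import numpy as np

class WordEmbeddingLookup:
    """Hypothetical MongoDB-backed word-embedding lookup with a local cache."""

    def __init__(self, coll, dim=200):
        self._coll = coll
        self._dim = dim
        self._cache = {}

    def get(self, word):
        if word not in self._cache:
            doc = self._coll.find_one({'word': word})
            if doc is None:
                vec = np.zeros(self._dim, dtype=np.float32)  # unknown word
            else:
                vec = np.asarray(doc['vec'], dtype=np.float32)
            self._cache[word] = vec
        return self._cache[word]
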
Example #5
def main(args):
    dataset_name = 'mnist'
    with pymongo.MongoClient('sis2.ustcdm.org') as conn:
        db = conn['images']
        coll_train = db['%s_train' % dataset_name]
        coll_test = db['%s_test' % dataset_name]

        label_set = set()
        label_set.update(coll_train.distinct('label_index'))
        label_set.update(coll_test.distinct('label_index'))
        num_classes = len(label_set)

        ds_train = DataSource(coll_train, num_classes, True)
        ds_train = ph.io.BatchSource(ds_train, args.batch_size)
        ds_train = ph.io.ThreadBufferedSource(ds_train, 100)
        ds_test = DataSource(coll_test, num_classes, False)
        ds_test = ph.io.BatchSource(ds_test, args.batch_size)

        model = ram.RAM('ram',
                        12,
                        12,
                        1,
                        3,
                        128,
                        128,
                        256,
                        state_size=256,
                        num_classes=num_classes,
                        num_steps=6,
                        stddev=0.1)
        ph.initialize_global_variables()

        for i in range(1, args.num_loops + 1):
            batch = ds_train.next()
            if batch is None:
                batch = ds_train.next()
            image, label = batch

            loss, reward = model.train(image, label)
            if i % 50 == 0:
                print(i, 'loss:', loss, 'reward', reward)

            if i % 100 == 0:
                cal = ph.train.AccCalculator()
                for image, label in ds_test:
                    label_pred, = model.predict(image)
                    cal.update(label_pred, label)
                acc = cal.accuracy
                print()
                print('acc:', acc * 100)
                print()

    return 0
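
ph.io.ThreadBufferedSource(ds_train, 100) prefetches batches on a background thread. A minimal sketch of that idea, assuming the wrapped source raises StopIteration at the end of an epoch and the wrapper converts it into a None batch (which is why the loop above checks `if batch is None`), could be:

import queue
import threading

class ThreadBufferedSource:
    """Sketch of a prefetching wrapper: a daemon thread keeps up to
    buffer_size batches ready for next()."""

    def __init__(self, source, buffer_size):
        self._queue = queue.Queue(maxsize=buffer_size)
        worker = threading.Thread(target=self._fill, args=(source,), daemon=True)
        worker.start()

    def _fill(self, source):
        while True:
            try:
                self._queue.put(source.next())
            except StopIteration:
                # Translate the end-of-epoch signal into a None batch.
                self._queue.put(None)

    def next(self):
        return self._queue.get()
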
Example #6
    def _main(self, args):
        with pymongo.MongoClient('sis4.ustcdm.org') as conn:
            conn['admin'].authenticate('root', 'SELECT * FROM users;')
            db = conn['imagenet_dmde']
            coll_train = db['fusion_train']
            coll_valid = db['fusion_valid']
            coll_test = db['fusion_test']

            ds_train = DataSource(coll_train, True, args.batch_size)
            ds_valid = DataSource(coll_valid, False, args.batch_size)

            model = Model('model', 10184, 10184, args.keep_prob)
            ph.initialize_global_variables()

            progress = tqdm(total=args.num_loops, ncols=96, desc='Training')
            monitor = ph.train.EarlyStopping(5, model)
            for i in range(args.num_loops):
                self.checkpoint()
                try:
                    _, x, label = ds_train.next()
                except StopIteration:
                    _, x, label = ds_train.next()
                loss, = model.train(x, label)
                progress.set_description(f'Training loss={loss:.06f}',
                                         refresh=False)

                if (i + 1) % 1000 == 0:
                    progress_valid = tqdm(total=coll_valid.count(),
                                          ncols=96,
                                          desc='Validating')
                    cal = ph.train.AccCalculator()
                    for _, x, label in ds_valid:
                        label_pred, _ = model.predict(x)
                        cal.update(label_pred, label)
                        progress_valid.update(len(x))
                    progress_valid.close()
                    progress.clear()
                    print(f'[{i + 1}] Validation acc={cal.accuracy}')
                    if monitor.convergent(1 - cal.accuracy):
                        model.set_parameters(monitor.best_parameters)
                        break

                progress.update()
            progress.close()

            if args.write_results:
                coll_output = conn['imagenet_deepme']['final_dmde']
                self._write_result(model, coll_test, coll_output)

        return 0
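
ph.train.EarlyStopping(5, model) is only used through monitor.convergent(error) and monitor.best_parameters; a sketch with that interface is below, assuming the model also exposes a get_parameters() counterpart to the set_parameters() call seen above.

class EarlyStoppingMonitor:
    """Hypothetical early-stopping monitor: stops after `patience`
    validation checks without improvement and keeps the best snapshot."""

    def __init__(self, patience, model=None):
        self._patience = patience
        self._model = model
        self._best_error = float('inf')
        self._bad_checks = 0
        self.best_parameters = None

    def convergent(self, error):
        if error < self._best_error:
            self._best_error = error
            self._bad_checks = 0
            if self._model is not None:
                # get_parameters() is an assumed counterpart to set_parameters().
                self.best_parameters = self._model.get_parameters()
            return False
        self._bad_checks += 1
        return self._bad_checks >= self._patience
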
Example #7
def main(args):
    ph.set_tf_log_level(ph.TF_LOG_NO_WARN)
    with pymongo.MongoClient('sis3.ustcdm.org') as conn:
        conn['admin'].authenticate('root', 'SELECT * FROM users;')

        db = conn['imagenet_vgg']
        coll = db[args.coll]
        coll_output = db[f'h7_{args.coll}']
        coll_output.create_index('label_index')

        ds_test = DataSource(coll, False, args.batch_size)

        model = Model('model')
        ph.initialize_global_variables()
        ph.io.load_model_from_file(model['encoder'], args.vgg16, 'vgg16')

        progress = tqdm(total=coll.count(), ncols=96)
        buffer = []
        for _id, image, label in ds_test:
            h7, = model.predict(image)
            for _id_i, h7_i, label_i in zip(_id, h7, label):
                buffer.append({
                    '_id': _id_i,
                    'h7': gzip.compress(pickle.dumps(h7_i), 5),
                    'label_index': label_i
                })
                if len(buffer) >= 1000:
                    coll_output.insert_many(buffer)
                    buffer.clear()
            progress.update(len(image))
        if len(buffer) != 0:
            coll_output.insert_many(buffer)
            buffer.clear()
        progress.close()

    print('All clear.')
    return 0
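
The 'h7' field written above is a pickled, gzip-compressed feature vector (compression level 5); reading it back is just the reverse. The 4096-dimensional random vector below is only a placeholder:

import gzip
import pickle

import numpy as np

h7_i = np.random.rand(4096).astype(np.float32)   # placeholder feature vector
blob = gzip.compress(pickle.dumps(h7_i), 5)      # what gets stored in MongoDB
restored = pickle.loads(gzip.decompress(blob))   # what a reader recovers
assert np.array_equal(h7_i, restored)
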
Example #8
    def _main(self, args):
        mnist_data = mnist.input_data.read_data_sets('.', one_hot=False)
        ds_train = ph.io.BatchSource(DataSource(mnist_data.train),
                                     args.batch_size)
        ds_valid = ph.io.BatchSource(DataSource(mnist_data.validation),
                                     args.batch_size)
        ds_test = ph.io.BatchSource(DataSource(mnist_data.test),
                                    args.batch_size)

        model = Model('mnist_mlp', 1000)
        ph.initialize_global_variables()

        for i in range(1, args.num_loops + 1):
            self.checkpoint()
            try:
                image, label = ds_train.next()
            except StopIteration:
                image, label = ds_train.next()
            loss, = model.train(image, label)
            if i % 100 == 0:
                print(
                    f'Training [{i}/{args.num_loops}|{i / args.num_loops * 100:.02f}%]... loss={loss:.06f}'
                )

            if i % 500 == 0:
                acc = ph.train.AccCalculator()
                for image, label in ds_valid:
                    label_pred, = model.predict(image)
                    acc.update(label_pred, label)
                print(f'Validation acc={acc.accuracy * 100}%')
                acc = ph.train.AccCalculator()
                for image, label in ds_test:
                    label_pred, = model.predict(image)
                    acc.update(label_pred, label)
                print(f'Test acc={acc.accuracy * 100}%')
        return 0
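
ph.train.AccCalculator is only used through update() and .accuracy; a minimal sketch with that interface (the internals are assumptions) is:

import numpy as np

class AccCalculator:
    """Hypothetical accumulating accuracy counter."""

    def __init__(self):
        self._hit = 0
        self._count = 0

    def update(self, label_pred, label_true):
        label_pred = np.asarray(label_pred)
        label_true = np.asarray(label_true)
        self._hit += int(np.sum(label_pred == label_true))
        self._count += int(label_true.size)

    @property
    def accuracy(self):
        return self._hit / self._count if self._count else 0.0
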
Example #9
    def _main(self, args):
        ph.set_tf_log_level(ph.TF_LOG_NO_WARN)
        with pymongo.MongoClient('sis3.ustcdm.org') as conn:
            conn['admin'].authenticate('root', 'SELECT * FROM users;')

            ################################################################################
            # define data source and init model
            ################################################################################
            db = conn['imagenet_dmde']
            coll_train = db[f'task_{args.task_index:02d}_train']
            coll_valid = db[f'task_{args.task_index:02d}_valid']
            num_classes = len(coll_train.distinct('label_index'))
            print(f'Found {num_classes} classes.')

            ds_train = DataSource(coll_train, True, args.batch_size)
            ds_valid = DataSource(coll_valid, False, args.batch_size)

            model = Model('model',
                          hidden_size=args.hidden_size,
                          num_classes=num_classes,
                          keep_prob=args.keep_prob,
                          reg=args.reg,
                          grad_clip=args.grad_clip,
                          learning_rate_1=args.learning_rate_1,
                          learning_rate_2=args.learning_rate_2,
                          num_loops_1=args.num_loops_1,
                          num_loops_2=args.num_loops_2)
            ph.initialize_global_variables()
            ph.io.load_model_from_file(model['encoder'], args.alexnet,
                                       'alexnet')

            ################################################################################
            # pre-train
            ################################################################################
            progress = tqdm(total=args.num_loops_1, ncols=96)
            for i in range(args.num_loops_1):
                self.checkpoint()
                try:
                    _, image, label = ds_train.next()
                except StopIteration:
                    _, image, label = ds_train.next()
                loss, lr = model.train(image, label)
                progress.set_description(
                    f'Pre-train loss={loss:.03e}, lr={lr:.03e}', refresh=False)
                progress.update()
            progress.close()

            progress = tqdm(total=coll_valid.count(),
                            ncols=96,
                            desc='Validating')
            cal = ph.train.AccCalculator()
            for _, image, label in ds_valid:
                label_pred, _ = model.predict(image)
                cal.update(label_pred, label)
                progress.update(len(image))
            progress.close()
            print(f'Validation acc={cal.accuracy}')

            ################################################################################
            # fine tune
            ################################################################################
            progress = tqdm(total=args.num_loops_2, ncols=96)
            monitor = ph.train.EarlyStopping(5, model)
            for i in range(args.num_loops_2):
                self.checkpoint()
                try:
                    _, image, label = ds_train.next()
                except StopIteration:
                    _, image, label = ds_train.next()
                loss, lr = model.fine_tune(image, label)
                progress.set_description(
                    f'Fine tune loss={loss:.03e}, lr={lr:.03e}', refresh=False)

                if (i + 1) % 1000 == 0:
                    progress_valid = tqdm(total=coll_valid.count(),
                                          ncols=96,
                                          desc='Validating')
                    cal = ph.train.AccCalculator()
                    for _, image, label in ds_valid:
                        label_pred, _ = model.predict(image)
                        cal.update(label_pred, label)
                        progress_valid.update(len(image))
                    progress_valid.close()
                    progress.clear()
                    print(f'[{i + 1}] Validation acc={cal.accuracy}')
                    if monitor.convergent(1 - cal.accuracy):
                        model.set_parameters(monitor.best_parameters)
                        break
                progress.update()
            progress.close()
            ds_train = None
            ds_valid = None

            ################################################################################
            # write results
            ################################################################################
            if args.write_results:
                coll = conn['imagenet']['imagenet_10k_224_train']
                coll_output = db[f'alexnet_result_{args.task_index:02d}_train']
                self._write_result(model, coll, coll_output)

                coll = conn['imagenet']['imagenet_10k_224_valid']
                coll_output = db[f'alexnet_result_{args.task_index:02d}_valid']
                self._write_result(model, coll, coll_output)

                coll = conn['imagenet']['imagenet_10k_224_test']
                coll_output = db[f'alexnet_result_{args.task_index:02d}_test']
                self._write_result(model, coll, coll_output)

        print('All clear.')
        return 0
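
The validation block (tqdm over ds_valid, an accuracy counter, predict) repeats verbatim in several of these examples; a small helper capturing the pattern, reusing the AccCalculator sketched under Example #8 and assuming predict() returns (label_pred, ...) as it does here, could be:

from tqdm import tqdm

def run_validation(model, ds_valid, total):
    """One validation pass over ds_valid; `total` is the document count
    (e.g. coll_valid.count()) used only for the progress bar."""
    progress = tqdm(total=total, ncols=96, desc='Validating')
    cal = AccCalculator()
    for _, image, label in ds_valid:
        label_pred, _ = model.predict(image)
        cal.update(label_pred, label)
        progress.update(len(image))
    progress.close()
    return cal.accuracy
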
Example #10
    def _main(self, args):
        # ph.get_session_config().log_device_placement = True
        ph.set_tf_log_level(ph.TF_LOG_NO_WARN)
        with pymongo.MongoClient('sis4.ustcdm.org') as conn:
            conn['admin'].authenticate('root', 'SELECT * FROM users;')

            db = conn['imagenet_vgg']
            coll_train = db['train']
            coll_valid = db['valid']
            num_classes = len(coll_train.distinct('label_index'))
            print(f'Found {num_classes} classes.')

            ds_train = DataSource(coll_train, True, args.batch_size)
            ds_valid = DataSource(coll_valid, False, args.batch_size)

            model = Model('model', num_classes, len(args.gpu.split(',')),
                          args.batch_size)
            ph.initialize_global_variables()
            ph.io.load_model_from_file(model['encoder'], args.vgg16, 'vgg16')

            #
            # train the last layer
            bar = tqdm(total=args.num_train, ncols=96, desc='Training')
            for i in range(args.num_train):
                self.checkpoint()
                try:
                    _, image, label = ds_train.next()
                except StopIteration:
                    _, image, label = ds_train.next()
                loss, = model.train(image, label)
                bar.update()
                bar.set_description(f'Training loss={loss:.06f}')
            bar.close()

            #
            # validation
            bar = tqdm(total=coll_valid.count(), ncols=96, desc='Validating')
            cal = ph.train.AccCalculator()
            for _, image, label in ds_valid:
                label_pred, _ = model.predict(image)
                cal.update(label_pred, label)
                bar.update(len(image))
            bar.close()
            print(f'Validation acc={cal.accuracy}')

            #
            # fine tuning all the parameters
            bar = tqdm(total=args.num_loops, ncols=96, desc='Fine tuning')
            es = ph.train.EarlyStopping(3)
            for i in range(args.num_loops):
                self.checkpoint()
                try:
                    _, image, label = ds_train.next()
                except StopIteration:
                    _, image, label = ds_train.next()
                loss, = model.fine_tune(image, label)
                bar.set_description(f'Fine tuning loss={loss:.06f}')

                if (i + 1) % 10000 == 0:
                    bar1 = tqdm(total=coll_valid.count(),
                                ncols=96,
                                desc='Validating')
                    cal = ph.train.AccCalculator()
                    for _, image, label in ds_valid:
                        label_pred, _ = model.predict(image)
                        cal.update(label_pred, label)
                        bar1.update(len(image))
                    bar1.close()
                    bar.clear()
                    print(f'Validation acc={cal.accuracy}')
                    if es.convergent(1 - cal.accuracy):
                        break
                bar.update()
            bar.close()

            coll_test = db['test']
            bar = tqdm(total=coll_test.count(), ncols=96, desc='Testing')
            ds_test = DataSource(coll_test, False, args.batch_size)
            coll_output = db['result_vgg']
            buffer = []
            cal = ph.train.AccCalculator()
            for _id, image, label_index in ds_test:
                label_pred, y = model.predict(image)
                cal.update(label_pred, label_index)
                for _id_i, label_index_i, y_i in zip(_id, label_index, y):
                    doc = {
                        '_id': _id_i,
                        'label_index': label_index_i,
                        'y': gzip.compress(pickle.dumps(y_i), 7)
                    }
                    buffer.append(doc)
                    if len(buffer) >= 1000:
                        coll_output.insert_many(buffer)
                        buffer.clear()
                bar.update(len(image))
            if len(buffer) != 0:
                coll_output.insert_many(buffer)
                buffer.clear()
            bar.close()
            print(f'Final acc={cal.accuracy}')

        print('All clear.')
        return 0
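
The buffered insert_many pattern above (flush every 1000 documents, then flush the remainder at the end) can be factored into a small helper; a sketch:

class BufferedInserter:
    """Collects documents and writes them to a MongoDB collection in bulk."""

    def __init__(self, coll, flush_size=1000):
        self._coll = coll
        self._flush_size = flush_size
        self._buffer = []

    def append(self, doc):
        self._buffer.append(doc)
        if len(self._buffer) >= self._flush_size:
            self.flush()

    def flush(self):
        if self._buffer:
            self._coll.insert_many(self._buffer)
            self._buffer.clear()
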
Example #11
def main1(args):
    log_file = 'logs_' + args.name + str(args.start)  # append start so runs can be told apart
    model_dir = 'models_' + args.name + str(args.start)
    exp = common.Experiment(log_file,
                            model_dir)  # before each run, an existing directory with the same name as model_dir is deleted
    # we = ph.utils.WordEmbedding()  # modified the original function; after initialization, embeddings can be queried directly
    bert_client = BertClient(ip='202.201.242.38')
    trainer = Main(args.name + str(args.start), 768)  # pass the model name and the vector size (768-dimensional BERT vectors)
    ph.initialize_global_variables()
    #
    test_list = []

    for i in range(domain_num):
        print('***********************')
        print('domain:' + str(i) + dom_list[i])
        # rewritten data loading; reads the data for one domain
        train_data, dev_data, test_data = build_dataset_LL(i, args, bert_client)
        # implements the DataSource from the original code as a data loader
        train_ds = common.TrainSource(train_data, i)
        dev_ds = common.TrainSource(dev_data, i)
        test_ds = common.TrainSource(test_data, i)
        test_list.append(test_ds)
        #
        exp.load_model(trainer)  # resume from the previously saved model; the first domain trains from scratch
        seq_stat = trainer.stat.read_stat(trainer.flat_seq)
        states_stat = trainer.stat.read_stat(trainer.flat_states)
        trainer._optimizer.update_mask(trainer.shared.cell.wz, seq_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.wr, seq_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.wh, seq_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.uz, states_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.ur, states_stat, i)
        trainer._optimizer.update_mask(trainer.shared.cell.uh, states_stat, i)
        trainer.add_data_trainer(train_ds, 64)  # batch size 64
        # trainer.add_screen_logger('train', ('Loss', 'Norm'), interval=1)  # prints training progress
        trainer.add_data_validator(test_ds, 64, interval=20)  # analogous to model.eval()
        # trainer.add_screen_logger(  # eval results
        #     "validate",
        #     ('hit_pos', 'hit_neg', 'pred_pos', 'pred_neg', 'Error'),
        #     message='[%d]' % i,
        #     interval=20
        # )
        trainer.add_fitter(common.DevFitter(dev_ds, 64, 20))
        trainer.fit(args.num_loops)
        trainer.clear_fitters()  # training for this domain ends here
        #
        exp.dump_model(trainer)  # save the model

        # test turn
        for test_data in test_list:
            trainer.add_data_validator(test_data, 64,
                                       interval=1)  # model.eval()
            trainer.add_screen_logger(  # print the test results
                "validate",
                ('hit_pos', 'hit_neg', 'pred_pos', 'pred_neg', 'Error'),
                message='[%d]' % i,
                interval=1)
            trainer.fit(1)
            trainer.clear_fitters()
        trainer.stat.update_stats()
    bert_client.close()
    return 0
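
build_dataset_LL presumably uses the BertClient (bert-as-service) to turn sentences into the 768-dimensional vectors passed to Main; basic usage of that client looks like this. The server address is the one from the example, and the sample sentences are placeholders:

from bert_serving.client import BertClient

bc = BertClient(ip='202.201.242.38')
vectors = bc.encode(['this product is great', 'battery life is terrible'])
print(vectors.shape)  # (2, 768) for a BERT-base server
bc.close()
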
Example #12
    def _main(self, args):
        ph.set_tf_log_level(ph.TF_LOG_NO_WARN)
        with pymongo.MongoClient('sis3.ustcdm.org') as conn:
            conn['admin'].authenticate('root', 'SELECT * FROM users;')

            db = conn[args.db_name]
            coll_train = db[f'task_{args.task_index:02d}_train']
            coll_valid = db[f'task_{args.task_index:02d}_valid']
            num_classes = len(coll_train.distinct('label_index'))
            print(f'Found {num_classes} classes.')

            ds_train = DataSource(coll_train, True, args.batch_size)
            ds_valid = DataSource(coll_valid, False, args.batch_size)

            model = Model('model', num_classes, args.keep_prob)
            ph.initialize_global_variables()
            ph.io.load_model_from_file(model['encoder'], args.vgg16, 'vgg16')

            #
            # train the last layer
            progress = tqdm(total=args.num_train, ncols=96, desc='Training')
            for i in range(args.num_train):
                self.checkpoint()
                try:
                    _, image, label = ds_train.next()
                except StopIteration:
                    _, image, label = ds_train.next()
                loss, = model.train(image, label)
                progress.set_description(f'Training loss={loss:.06f}',
                                         refresh=False)
                progress.update()
            progress.close()

            #
            # validation
            progress = tqdm(total=coll_valid.count(),
                            ncols=96,
                            desc='Validating')
            cal = ph.train.AccCalculator()
            for _, image, label in ds_valid:
                label_pred, _ = model.predict(image)
                cal.update(label_pred, label)
                progress.update(len(image))
            progress.close()
            print(f'Validation acc={cal.accuracy}')

            #
            # fine tuning all the parameters
            progress = tqdm(total=args.num_loops, ncols=96, desc='Fine tuning')
            monitor = ph.train.EarlyStopping(5, model)
            for i in range(args.num_loops):
                self.checkpoint()
                try:
                    _, image, label = ds_train.next()
                except StopIteration:
                    _, image, label = ds_train.next()
                loss, = model.fine_tune(image, label)
                progress.set_description(f'Fine tuning loss={loss:.06f}',
                                         refresh=False)

                if (i + 1) % 1000 == 0:
                    progress_valid = tqdm(total=coll_valid.count(),
                                          ncols=96,
                                          desc='Validating')
                    cal = ph.train.AccCalculator()
                    for _, image, label in ds_valid:
                        label_pred, _ = model.predict(image)
                        cal.update(label_pred, label)
                        progress_valid.update(len(image))
                    progress_valid.close()
                    progress.clear()
                    print(f'[{i + 1}] Validation acc={cal.accuracy}')
                    if monitor.convergent(1 - cal.accuracy):
                        model.set_parameters(monitor.best_parameters)
                        break
                progress.update()
            progress.close()
            ds_train = None
            ds_valid = None

            if args.write_results:
                coll = conn['imagenet_vgg']['train']
                coll_output = db[f'result_{args.task_index:02d}_train']
                self._write_result(model, coll, coll_output)

                coll = conn['imagenet_vgg']['valid']
                coll_output = db[f'result_{args.task_index:02d}_valid']
                self._write_result(model, coll, coll_output)

                coll = conn['imagenet_vgg']['test']
                coll_output = db[f'result_{args.task_index:02d}_test']
                self._write_result(model, coll, coll_output)

        print('All clear.')
        return 0