def main():
    """Run two-stream inference: load pretrained RGB and optical-flow
    models, run them over the test split, and evaluate the fused
    detection results with mAP against the ground-truth JSON.
    """
    # convert to test mode (the original comment said "train mode",
    # which contradicted the assignment below)
    config.MODE = 'test'
    extra()

    # create a logger
    logger = create_logger(config, 'test')

    # logging configurations
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # mask visible devices, then address them by local index 0..n-1
    os.environ["CUDA_VISIBLE_DEVICES"] = config.GPUS
    gpus = [int(i) for i in config.GPUS.split(',')]
    gpus = list(range(len(gpus)))  # idiomatic len() instead of __len__()

    # load both pretrained streams; strict=True means the checkpoints
    # must match the model definitions exactly
    model_rgb = create_model()
    model_rgb.my_load_state_dict(torch.load(config.TEST.STATE_DICT_RGB),
                                 strict=True)
    model_rgb = model_rgb.cuda(gpus[0])

    model_flow = create_model()
    model_flow.my_load_state_dict(torch.load(config.TEST.STATE_DICT_FLOW),
                                  strict=True)
    model_flow = model_flow.cuda(gpus[0])

    # load data (one dataset per modality)
    test_dataset_rgb = get_dataset(mode='test', modality='rgb')
    test_dataset_flow = get_dataset(mode='test', modality='flow')

    # NOTE(review): these loaders are built but test_final below is
    # handed the datasets directly — presumably it iterates them itself;
    # confirm whether the loaders are needed at all.
    test_loader_rgb = torch.utils.data.DataLoader(
        test_dataset_rgb,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)
    test_loader_flow = torch.utils.data.DataLoader(
        test_dataset_flow,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    # fuse both streams, dump results, and score them
    result_file_path = test_final(test_dataset_rgb, model_rgb,
                                  test_dataset_flow, model_flow)
    eval_mAP(config.DATASET.GT_JSON_PATH, result_file_path)
Example #2
0
    def __init__(self, settings):
        """Set up the SR training run: loss, model, dataset, loaders,
        and optimizer, all driven by the ``settings`` namespace.

        Relies on the parent class to populate ``self.phase``,
        ``self.data_dir``, ``self.list_dir``, ``self.batch_size``,
        ``self.lr`` and ``self.logger`` — TODO confirm against the
        base-class definition (not visible here).
        """
        super(SRTrainer, self).__init__(settings)
        self.scale = settings.scale
        # __dict__.get(...) tolerates settings objects that lack the
        # optional 'weights' / 'feat_names' attributes (returns None)
        self.criterion = ComLoss(settings.iqa_model_path,
                                 settings.__dict__.get('weights'),
                                 settings.__dict__.get('feat_names'),
                                 settings.alpha, settings.iqa_patch_size,
                                 settings.criterion)
        if hasattr(self.criterion, 'iqa_loss'):
            # For saving cost: freeze the IQA sub-network so it is not
            # updated (or tracked) during training
            self.criterion.iqa_loss.freeze()

        self.model = build_model(ARCH, scale=self.scale)
        self.dataset = get_dataset(DATASET)

        # Only the training phase needs a shuffled, augmented DataLoader
        if self.phase == 'train':
            self.train_loader = torch.utils.data.DataLoader(
                self.dataset(self.data_dir,
                             'train',
                             self.scale,
                             list_dir=self.list_dir,
                             transform=Compose(
                                 MSCrop(self.scale, settings.patch_size),
                                 Flip()),
                             repeats=settings.reproduce),
                batch_size=self.
                batch_size,  #max(self.batch_size//settings.reproduce, 1),
                shuffle=True,
                num_workers=settings.num_workers,
                pin_memory=True,
                drop_last=True)

        # Validation set is used directly (no DataLoader wrapper)
        self.val_loader = self.dataset(self.data_dir,
                                       'val',
                                       self.scale,
                                       subset=settings.subset,
                                       list_dir=self.list_dir)

        # lr_avai: flag exposed by the dataset indicating whether
        # low-resolution inputs exist on disk
        if not self.val_loader.lr_avai:
            self.logger.warning(
                "warning: the low-resolution sources are not available")

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          betas=(0.9, 0.999),
                                          lr=self.lr,
                                          weight_decay=settings.weight_decay)
        # self.optimizer = torch.optim.RMSprop(
        #     self.model.parameters(),
        #     lr=self.lr,
        #     alpha=0.9,
        #     weight_decay=settings.weight_decay
        # )

        self.logger.dump(self.model)  # Log the architecture
Example #3
0
    def __init__(self, path_to_model_weight, patch_size, feat_names):
        """Build the IQA-based loss.

        Loads IQANet weights from ``path_to_model_weight`` when the file
        exists, and filters ``feat_names`` down to the feature names the
        model actually exposes (dropping the special 'nr' entry).
        """
        super(IQALoss, self).__init__()

        self.iqa_model = IQANet(weighted=False)
        if os.path.exists(path_to_model_weight):
            checkpoint = torch.load(path_to_model_weight)
            self.iqa_model.load_state_dict(checkpoint['state_dict'])

        self.patch_size = patch_size
        # Keep only names that (a) are not the 'nr' marker and
        # (b) correspond to attributes present on the IQA model.
        valid_names = []
        for name in feat_names:
            if name == 'nr':
                continue
            if hasattr(self.iqa_model, name):
                valid_names.append(name)
        self.feat_names = valid_names
        # 'nr' in the requested names switches on the regularization term.
        self.regular = 'nr' in feat_names
        self._denorm = get_dataset(DATASET).denormalize
Example #4
0
    def __init__(self,
                 scale,
                 data_dir,
                 ckp_path,
                 save_lr=False,
                 list_dir='',
                 out_dir='./',
                 log_dir=''):
        """Set up a super-resolution predictor over the test split.

        Everything except ``scale`` is forwarded verbatim to the base
        predictor class; ``scale`` selects the SR factor for both the
        model and the dataset.
        """
        super(SRPredictor, self).__init__(
            data_dir=data_dir,
            ckp_path=ckp_path,
            save_lr=save_lr,
            list_dir=list_dir,
            out_dir=out_dir,
            log_dir=log_dir)
        self.scale = scale

        # Instantiate the network and look up the dataset class.
        self.model = build_model(constants.ARCH, scale=self.scale)
        self.dataset = get_dataset(constants.DATASET)

        # The dataset object itself serves as the test "loader".
        self.test_loader = self.dataset(
            self.data_dir, 'test', self.scale, list_dir=self.list_dir)
Example #5
0
def training_loop(config: Config):
    """Multi-GPU TF1-style GAN training loop.

    Builds per-GPU generator/discriminator towers, aggregates gradients
    on the CPU, and runs the train/eval/snapshot cycle inside a single
    Session.

    NOTE(review): the Keras model/compile/dataset code after the save
    step (see below) references names (``params``, ``gpu_ids``,
    ``SpmModel``, ``MSELoss``, ``SmoothL1Loss``) that are not defined in
    this function — it appears to be spliced in from a different script
    and is likely dead/broken; confirm before relying on it.
    """
    timer = Timer()
    print("Start task {}".format(config.task_name))
    dataset = get_dataset(name=config.dataset, data_dir=config.data_dir, seed=config.seed)

    # Input pipelines and the global step live on the CPU.
    with tf.device('/cpu:0'):
        print("Constructing networks...")
        Network = fsGAN.Network(dataset=dataset, model_dir=config.model_dir, run_dir=config.run_dir)
        # Per-GPU batch size: total batch is split evenly across GPUs.
        data_iter = Network.input_data_as_iter(
            batch_size=config.batch_size // config.gpu_nums, seed=config.seed, mode="train")

        eval_iter = Network.input_data_as_iter(
            batch_size=config.batch_size // config.gpu_nums, seed=config.seed, mode="eval")
        global_step = tf.compat.v1.get_variable(
            'global_step', [],
            initializer=tf.constant_initializer(0), trainable=False)
    print("Building Tensorflow graph...")
    # One gradient tower per GPU; gradients are pooled and applied on CPU.
    g_grad_pool = []
    d_grad_pool = []
    for gpu in range(config.gpu_nums):
        with tf.name_scope("GPU%d" % gpu), tf.device('/gpu:%d' % gpu):
            fs, ls = data_iter.get_next()
            fs, ls = Network.generate_samples(fs, ls)
            g_loss, d_loss = Network.create_loss(fs, ls)
            g_op = Network.get_gen_optimizer()
            d_op = Network.get_disc_optimizer()
            # control_dependencies forces the loss to be computed before
            # its gradients in the graph.
            with tf.control_dependencies([g_loss]):
                g_grad_pool.append(g_op.compute_gradients(g_loss, Network.generator.trainable_variables))
            with tf.control_dependencies([d_loss]):
                d_grad_pool.append(d_op.compute_gradients(d_loss, Network.discriminator.trainable_variables))
    with tf.device('/cpu:0'):
        g_update_op = Network.update(g_grad_pool, g_op)
        d_update_op = Network.update(d_grad_pool, d_op)
        # ma_op: presumably a moving-average update of generator weights
        # — TODO confirm against Network.ma_op.
        g_ma_op = Network.ma_op(global_step=global_step)
        merge_op = Network.summary()
        f_eval, l_eval = eval_iter.get_next()
        [inception_score, fid] = Network.eval(f_eval, l_eval)

    saver = tf.train.Saver()
    print('Start training...\n')
    # with tf.Session(config=tf.ConfigProto(
    #         allow_soft_placement=True)) as sess:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run([data_iter.initializer, eval_iter.initializer])
        # Grab one fixed batch so real/fake snapshots are comparable.
        fsnap, lsnap = sess.run(data_iter.get_next())
        fakes, _ = Network.generate_samples(fsnap, lsnap, is_training=False)
        save_image_grid(fsnap["images"], filename=config.model_dir + '/reals.png')
        summary_writer = tf.summary.FileWriter(logdir=config.model_dir, graph=sess.graph)
        for step in range(config.total_step):
            # disc_iters discriminator updates per generator update.
            for D_repeat in range(Network.disc_iters):
                sess.run([d_update_op, g_ma_op])
            sess.run([g_update_op])
            if step % config.summary_per_steps == 0:
                summary_file = sess.run(merge_op)
                summary_writer.add_summary(summary_file, step)
            # Evaluate halfway between save points (offset by half period).
            if step % config.eval_per_steps == config.eval_per_steps // 2:
                timer.update()
                fakes_np = sess.run(fakes["generated"])
                save_image_grid(fakes_np, filename=config.model_dir + '/fakes%06d.png' % step)
                [inception_score_eval, fid_eval] = sess.run([inception_score, fid])
                print("Time %s, fid %f, inception_score %f ,step %d" %
                      (timer.runing_time, fid_eval, inception_score_eval, step))
            if step % config.save_per_steps == 0:
                saver.save(sess, save_path=config.model_dir + '/model.ckpt', global_step=step)
        # define model
        # NOTE(review): everything from here on looks like a merge
        # artifact from an unrelated Keras training script.
        inputs = tf.keras.Input(shape=(params['height'], params['width'], 3), name='modelInput')
        outputs = SpmModel(inputs, num_joints=params['joints'], is_training=True)
        model = tf.keras.Model(inputs, outputs)


        model.compile(loss={'root_joints_conv1x1': MSELoss, 'reg_map_conv1x1': SmoothL1Loss},
                      loss_weights={'root_joints_conv1x1': 1, 'reg_map_conv1x1': 1},
                      optimizer=tf.keras.optimizers.Adam(1e-4))

        if params['finetune'] is not None:
            model.load_weights(params['finetune'])


    # define dataset
    # NOTE(review): gpu_ids is not defined in this function — verify.
    train_dataset = get_dataset(num_gpus=len(gpu_ids), mode='train')
    test_dataset  = get_dataset(num_gpus=len(gpu_ids), mode='val')
    def generator(dataset):
        # Adapt (input, out1, out2) tuples to Keras' named-output dict form.
        for input, output1, output2 in dataset:
            yield input, {'root_joints_conv1x1':output1, 'reg_map_conv1x1':output2}


    def step_lr(epoch):
        # Piecewise-constant learning-rate schedule by epoch.
        if epoch < 20:
            return 1e-3
        elif epoch < 50:
            return 1e-4
        else:
            return 1e-5

    # def callbacks
Example #7
0
import cv2
import numpy as np
import os

# Manual smoke-test of the SPM data pipeline: iterate the dataset once and
# decode the first sample of each batch. Disabled by default.
use_dataset = False

if use_dataset:
    import tensorflow as tf
    from decoder.decode_spm import SpmDecoder
    from dataset.dataset import get_dataset
    from config.spm_config import spm_config as params
    # Force CPU-only execution for this debug pass.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    mode = 'train'
    dataset = get_dataset(mode=mode)

    # BGR colors (OpenCV convention — cv2 is imported above) for drawing.
    colors = [[0,0,255],[255,0,0],[0,255,0]]
    for epco in range(1):
        for step, (img, center_map, center_mask, kps_map, kps_map_weight) in enumerate(dataset):
            # print (step)
            # print (img[0].shape)
            # img1 = img[0]
            # label1 = label[0]
            # break
            print ('epoch {} / step {}'.format(epco, step))
            # First image of the batch, de-normalized back to uint8.
            # assumes img values are in [0, 1] — TODO confirm against the
            # dataset's preprocessing.
            img = (img.numpy()[0] * 255).astype(np.uint8)

            # Decoder works at 1/4 output resolution with a fixed 4x factor.
            spm_decoder = SpmDecoder(4, 4, params['height']//4, params['width']//4)
            results = spm_decoder([center_map[0].numpy(), kps_map[0].numpy()])
Example #8
0
    # Build the SPM model, restore a checkpoint, and decode predictions on
    # the validation split. NOTE(review): use_gpu, netH, netW, ckpt_path,
    # score and dist are defined outside this visible fragment — verify.
    use_nms = True

    # Suppress INFO-level TensorFlow C++ logs.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

    if use_gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    inputs = tf.keras.Input(shape=(netH, netW, 3), name='modelInput')
    # is_training=False: inference graph (14 joints hard-coded here).
    outputs = SpmModel(inputs, 14, is_training=False)
    model = tf.keras.Model(inputs, outputs)
    ckpt = tf.train.Checkpoint(net=model)
    ckpt.restore(ckpt_path)

    val_dataset = get_dataset(mode='val')
    predictions = []

    for step, (imgids, heights, widths, imgs) in enumerate(val_dataset):
        center_map, kps_reg_map = model(imgs)
        imgids = imgids.numpy()
        for b in range(params['batch_size']):
            # Scale factors map 1/4-resolution network outputs back to the
            # original image size.
            factor_x = widths[b].numpy() / (netW / 4)
            factor_y = heights[b].numpy() / (netH / 4)
            spm_decoder = SpmDecoder(factor_x, factor_y, netH // 4, netW // 4)
            joints, centers = spm_decoder([center_map[b], kps_reg_map[b]],
                                          score_thres=score,
                                          dis_thres=dist)
            # Image ids arrive as bytes; decode to str for the result dict.
            img_id = str(imgids[b], encoding='utf-8')

            predict = {}
Example #9
0
        # Build the SPM model and optionally restore a fine-tune checkpoint.
        # NOTE(review): strategy, params and gpu_ids come from outside this
        # visible fragment — verify.
        inputs = tf.keras.Input(shape=(params['height'], params['width'], 3),
                                name='modelInput')
        outputs = SpmModel(inputs,
                           num_joints=params['joints'],
                           is_training=True)
        model = tf.keras.Model(inputs, outputs)
        optimizer = tf.optimizers.Adam(learning_rate=3e-4)
        checkpoint = tf.train.Checkpoint(optimizer=optimizer, net=model)
        if params['finetune'] is not None:
            # assert_existing_objects_matched raises if checkpoint and model
            # structures disagree.
            checkpoint.restore(
                params['finetune']).assert_existing_objects_matched()
            print('Successfully restore model from {}'.format(
                params['finetune']))

    with strategy.scope():
        dataset = get_dataset(len(gpu_ids))
        dist_dataset = strategy.experimental_distribute_dataset(dataset)
        # Debug: peeks at a private attribute of the distributed dataset.
        print(dist_dataset.__dict__['_cloned_datasets'])

    with strategy.scope():

        def SmoothL1Loss(label, pred, weight):
            # Elementwise smooth-L1 on the weighted difference; quadratic
            # below 1, linear above.
            t = tf.abs(label * weight - pred * weight)

            return tf.reduce_mean(tf.where(t <= 1, 0.5 * t * t, 0.5 * (t - 1)))

        def L2Loss(label, pred):
            return tf.reduce_mean(tf.losses.mse(label, pred))

        def comput_loss(center_map, kps_map, kps_map_weight, preds):
            kps_loss = SmoothL1Loss(kps_map, preds[1], kps_map_weight)
Example #10
0
from nets.spm_model import SpmModel
from config.center_config import center_config
from train.spm_train import train

import os
import datetime

if __name__ == '__main__':
    # Single-GPU SPM training entry point: build the model, create a
    # timestamped summary writer, and hand off to train().
    # NOTE(review): tf and get_dataset are not imported in the visible
    # import block above — verify they are brought in elsewhere.

    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    inputs = tf.keras.Input(shape=(center_config['height'],
                                   center_config['width'], 3),
                            name='modelInput')
    # outputs = center_net_model(inputs, num_joints=center_config['joints'], training=True)
    outputs = SpmModel(inputs,
                       num_joints=center_config['joints'],
                       is_training=True)
    model = tf.keras.Model(inputs, outputs)

    # Timestamp used to namespace this run's logs.
    cur_time = datetime.datetime.fromtimestamp(
        datetime.datetime.now().timestamp()).strftime('%Y-%m-%d-%H-%M')

    optimizer = tf.optimizers.Adam(learning_rate=3e-4)
    dataset = get_dataset()
    epochs = 200
    summary_writer = tf.summary.create_file_writer(
        os.path.join('./logs/spm', cur_time))
    with summary_writer.as_default():
        train(model, optimizer, dataset, epochs, cur_time)
def main():
    """Alternately train the RGB and optical-flow streams (each with its
    conditional VAE), evaluating the two-stream fusion on the test split
    after every round and saving the best-scoring checkpoints.
    """
    # convert to train mode
    config.MODE = 'train'
    extra()

    # create a logger
    logger = create_logger(config, 'train')

    # logging configurations
    logger.info(pprint.pformat(config))

    # random seed
    if config.IF_DETERMINISTIC:
        torch.manual_seed(config.RANDOM_SEED_TORCH)
        config.CUDNN.DETERMINISTIC = True
        config.CUDNN.BENCHMARK = False
        np.random.seed(config.RANDOM_SEED_NUMPY)
        random.seed(config.RANDOM_SEED_RANDOM)
    else:
        logger.info('torch random seed: {}'.format(torch.initial_seed()))

        # BUGFIX: random.randint is inclusive on both ends, while
        # np.random.seed only accepts seeds in [0, 2**32 - 1]; the old
        # upper bound of 2**32 could raise ValueError.
        seed = random.randint(0, 2**32 - 1)
        np.random.seed(seed)
        logger.info('numpy random seed: {}'.format(seed))

        seed = random.randint(0, 2**32 - 1)
        random.seed(seed)
        logger.info('random random seed: {}'.format(seed))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    # create a model per modality, optionally resuming from a checkpoint
    gpus = [int(i) for i in config.GPUS.split(',')]

    model_rgb = create_model()
    if config.TRAIN.RESUME_RGB:
        model_rgb.my_load_state_dict(torch.load(config.TRAIN.STATE_DICT_RGB),
                                     strict=True)

    model_rgb = model_rgb.cuda(gpus[0])
    model_rgb = torch.nn.DataParallel(model_rgb, device_ids=gpus)

    model_flow = create_model()
    if config.TRAIN.RESUME_FLOW:
        model_flow.my_load_state_dict(torch.load(config.TRAIN.STATE_DICT_FLOW),
                                      strict=True)

    model_flow = model_flow.cuda(gpus[0])
    model_flow = torch.nn.DataParallel(model_flow, device_ids=gpus)

    # create a conditional-vae per modality
    cvae_rgb = create_cvae()
    cvae_rgb = cvae_rgb.cuda(gpus[0])
    cvae_rgb = torch.nn.DataParallel(cvae_rgb, device_ids=gpus)

    cvae_flow = create_cvae()
    cvae_flow = cvae_flow.cuda(gpus[0])
    cvae_flow = torch.nn.DataParallel(cvae_flow, device_ids=gpus)

    # create an optimizer for each model and each cvae
    optimizer_rgb = create_optimizer(config, model_rgb)
    optimizer_flow = create_optimizer(config, model_flow)
    optimizer_cvae_rgb = create_optimizer(config, cvae_rgb)
    optimizer_cvae_flow = create_optimizer(config, cvae_flow)

    # create a learning rate scheduler; T_max is in scheduler steps,
    # i.e. one per TEST_EVERY_EPOCH-epoch training round
    lr_scheduler_rgb = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_rgb,
        T_max=(config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH) //
        config.TRAIN.TEST_EVERY_EPOCH,
        eta_min=config.TRAIN.LR / 10)
    lr_scheduler_flow = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_flow,
        T_max=(config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH) //
        config.TRAIN.TEST_EVERY_EPOCH,
        eta_min=config.TRAIN.LR / 10)
    lr_scheduler_cvae_rgb = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_cvae_rgb,
        T_max=(config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH) //
        config.TRAIN.TEST_EVERY_EPOCH,
        eta_min=config.TRAIN.LR / 10)
    lr_scheduler_cvae_flow = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer_cvae_flow,
        T_max=(config.TRAIN.END_EPOCH - config.TRAIN.BEGIN_EPOCH) //
        config.TRAIN.TEST_EVERY_EPOCH,
        eta_min=config.TRAIN.LR / 10)

    # load data
    train_dataset_rgb = get_dataset(mode='train', modality='rgb')
    train_dataset_flow = get_dataset(mode='train', modality='flow')
    test_dataset_rgb = get_dataset(mode='test', modality='rgb')
    test_dataset_flow = get_dataset(mode='test', modality='flow')

    train_loader_rgb = torch.utils.data.DataLoader(
        train_dataset_rgb,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True)
    train_loader_flow = torch.utils.data.DataLoader(
        train_dataset_flow,
        batch_size=config.TRAIN.BATCH_SIZE * len(gpus),
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True,
        drop_last=True)
    test_loader_rgb = torch.utils.data.DataLoader(
        test_dataset_rgb,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)
    test_loader_flow = torch.utils.data.DataLoader(
        test_dataset_flow,
        batch_size=config.TEST.BATCH_SIZE * len(gpus),
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True)

    # training and validating

    best_perf = 0

    # best_model_* hold copies of the best weights seen so far; the
    # frozen counterpart is used when evaluating the other stream
    best_model_rgb = create_model()
    best_model_rgb = best_model_rgb.cuda(gpus[0])

    best_model_flow = create_model()
    best_model_flow = best_model_flow.cuda(gpus[0])

    for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH,
                       config.TRAIN.TEST_EVERY_EPOCH):
        # train rgb for **config.TRAIN.TEST_EVERY_EPOCH** epochs
        train(train_loader_rgb, model_rgb, cvae_rgb, optimizer_rgb,
              optimizer_cvae_rgb, epoch, config.TRAIN.TEST_EVERY_EPOCH, 'rgb')

        # evaluate on validation set (rgb fresh, flow frozen at its best)
        result_file_path_rgb = test_final(test_dataset_rgb, model_rgb.module,
                                          test_dataset_flow, best_model_flow)
        perf_indicator = eval_mAP(config.DATASET.GT_JSON_PATH,
                                  result_file_path_rgb)

        if best_perf < perf_indicator:
            logger.info("(rgb) new best perf: {:3f}".format(perf_indicator))
            best_perf = perf_indicator
            best_model_rgb.my_load_state_dict(model_rgb.state_dict(),
                                              strict=True)

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf))))
            torch.save(
                best_model_rgb.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf)))

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf))))
            torch.save(
                best_model_flow.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf)))

        # lr_scheduler_rgb.step(perf_indicator)
        # lr_scheduler_cvae_rgb.step()

        # train flow for **config.TRAIN.TEST_EVERY_EPOCH** epochs
        train(train_loader_flow, model_flow, cvae_flow, optimizer_flow,
              optimizer_cvae_flow, epoch, config.TRAIN.TEST_EVERY_EPOCH,
              'flow')

        # evaluate on validation set (flow fresh, rgb frozen at its best)
        result_file_path_flow = test_final(test_dataset_rgb, best_model_rgb,
                                           test_dataset_flow,
                                           model_flow.module)
        perf_indicator = eval_mAP(config.DATASET.GT_JSON_PATH,
                                  result_file_path_flow)

        if best_perf < perf_indicator:
            logger.info("(flow) new best perf: {:3f}".format(perf_indicator))
            best_perf = perf_indicator
            best_model_flow.my_load_state_dict(model_flow.state_dict(),
                                               strict=True)

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf))))
            torch.save(
                best_model_rgb.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_rgb_{}.pth'.format(best_perf)))

            logger.info("=> saving final result into {}".format(
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf))))
            torch.save(
                best_model_flow.state_dict(),
                os.path.join(config.OUTPUT_DIR,
                             'final_flow_{}.pth'.format(best_perf)))
Example #12
0
    # Set up a TF1 session, load config + dataset, and run the tester.
    # NOTE(review): args comes from outside this visible fragment — verify.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    tf.reset_default_graph()

    # load config file
    config = get_config(args.config_file, args.disp_config)

    # make the assets directory and copy the config file to it
    # so if you want to reproduce the result in assets dir
    # just copy the config_file.json to ./cfgs folder and run python3 train.py --config=(config_file)
    if not os.path.exists(config['assets dir']):
        os.makedirs(config['assets dir'])
    copyfile(os.path.join('./cfgs', args.config_file + '.json'),
             os.path.join(config['assets dir'], 'config_file.json'))

    # prepare dataset
    dataset = get_dataset(config['dataset'], config['dataset params'])

    # allow_growth avoids grabbing all GPU memory up front
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True

    with tf.Session(config=tfconfig) as sess:

        # build model
        config['model params']['assets dir'] = config['assets dir']
        model = get_model(config['model'], config['model params'])

        # start testing
        # NOTE(review): despite the 'tester' keys, this calls
        # trainer.train — confirm the trainer runs the test pass.
        config['tester params']['assets dir'] = config['assets dir']
        trainer = get_trainer(config['tester'], config['tester params'], model)
        trainer.train(sess, dataset, model)
import numpy as np
from knearestneighbour import KNearestNeighbour
import dataset.dataset as dataset

if __name__ == '__main__':
    # KNN benchmark on MNIST. Note: the module name `dataset` is rebound
    # to an instance here, exactly as in the original script.
    dataset = dataset.MnistDataset()
    (x_train, y_train), (x_test, y_test) = dataset.get_dataset()

    # Flatten 28x28 images to 784-vectors and normalize.
    x_train = dataset.normalize(x_train.reshape((60000, 784)))
    x_test = dataset.normalize(x_test.reshape((10000, 784)))

    # Only the first 600 test rows are classified (CPU cost per row).
    predicted = KNearestNeighbour().predict(x_train, y_train, x_test[:600])

    count = np.count_nonzero(predicted == y_test[:600])
    percentage = count / predicted.shape[0] * 100
    print(
        str(count) + " predicted correctly. That is " + str(percentage) + "%")

    # Benchmark with full training set but only 600 of the test rows (because on CPU it takes 0.9 sec. per row)
    # 578 predicted correctly. That is 96.33333333333334%
Example #14
0
                       num_joints=params['num_joints'],
                       is_training=True)
    model = tf.keras.Model(inputs, outputs)

    # Optionally warm-start from a pretrained checkpoint.
    if params['finetune'] is not None:
        model.load_weights(params['finetune'])
        print('Successfully load pretrained model from ... {}'.format(
            params['finetune']))

    # Timestamp used to namespace this run's summary logs.
    cur_time = datetime.datetime.fromtimestamp(
        datetime.datetime.now().timestamp()).strftime('%Y-%m-%d-%H-%M')
    summary_writer = tf.summary.create_file_writer(
        os.path.join('./logs/spm', cur_time))

    optimizer = tf.optimizers.Adam(learning_rate=1e-4)
    train_dataset = get_dataset(num_gpus=1, mode='train')
    test_dataset = get_dataset(num_gpus=1, mode='test')
    epochs = 150

    def lr_decay(epoch):
        # Piecewise-constant learning-rate schedule by epoch.
        if epoch < 90:
            return 1e-3
        elif epoch < 120:
            return 1e-4
        else:
            return 1e-5

    @tf.function
    def train_step(model, inputs):
        # Graph-compiled forward pass.
        return model(inputs)
def get_all_metrics(gen,
                    k=None,
                    n_jobs=1,
                    device='cpu',
                    batch_size=512,
                    pool=None,
                    test=None,
                    test_scaffolds=None,
                    ptest=None,
                    ptest_scaffolds=None,
                    train=None):
    """
    Computes all available metrics between test (scaffold test)
    and generated sets of SMILES.
    Parameters:
        gen: list of generated SMILES
        k: int or list with values for unique@k. Will calculate number of
            unique molecules in the first k molecules. Default [1000, 10000]
        n_jobs: number of workers for parallel processing
        device: 'cpu' or 'cuda:n', where n is GPU device number
        batch_size: batch size for FCD metric
        pool: optional multiprocessing pool to use for parallelization

        test (None or list): test SMILES. If None, will load
            a default test set
        test_scaffolds (None or list): scaffold test SMILES. If None, will
            load a default scaffold test set
        ptest (None or dict): precalculated statistics of the test set. If
            None, will load default test statistics. If you specified a custom
            test set, default test statistics will be ignored
        ptest_scaffolds (None or dict): precalculated statistics of the
            scaffold test set If None, will load default scaffold test
            statistics. If you specified a custom test set, default test
            statistics will be ignored
        train (None or list): train SMILES. If None, will load a default
            train set
    Available metrics:
        * %valid
        * %unique@k
        * Frechet ChemNet Distance (FCD)
        * Fragment similarity (Frag)
        * Scaffold similarity (Scaf)
        * Similarity to nearest neighbour (SNN)
        * Internal diversity (IntDiv)
        * Internal diversity 2: using square root of mean squared
            Tanimoto similarity (IntDiv2)
        * %passes filters (Filters)
        * Distribution difference for logP, SA, QED, weight
        * Novelty (molecules not present in train)
    """
    # Custom precomputed stats only make sense with a custom test set.
    if test is None:
        if ptest is not None:
            raise ValueError("You cannot specify custom test "
                             "statistics for default test set")
        test = get_dataset('test')
        ptest = get_statistics('test')

    if test_scaffolds is None:
        if ptest_scaffolds is not None:
            raise ValueError("You cannot specify custom scaffold test "
                             "statistics for default scaffold test set")
        test_scaffolds = get_dataset('test_scaffolds')
        ptest_scaffolds = get_statistics('test_scaffolds')

    # NOTE(review): `or` treats an empty list as missing, so an explicitly
    # empty `train` also triggers loading the default train set.
    train = train or get_dataset('train')

    if k is None:
        k = [1000, 10000]
    disable_rdkit_log()
    metrics = {}
    # Track whether we created the pool here, so we only close our own.
    close_pool = False
    if pool is None:
        if n_jobs != 1:
            pool = Pool(n_jobs)
            close_pool = True
        else:
            # presumably mapper() interprets 1 as "run serially" — confirm
            # against the project's mapper helper.
            pool = 1
    metrics['valid'] = fraction_valid(gen, n_jobs=pool)
    # All subsequent metrics are computed on the valid, canonized subset.
    gen = remove_invalid(gen, canonize=True)
    if not isinstance(k, (list, tuple)):
        k = [k]
    for _k in k:
        metrics['unique@{}'.format(_k)] = fraction_unique(gen, _k, pool)

    # Compute reference statistics only if they were not supplied/loaded.
    if ptest is None:
        ptest = compute_intermediate_statistics(test,
                                                n_jobs=n_jobs,
                                                device=device,
                                                batch_size=batch_size,
                                                pool=pool)
    if test_scaffolds is not None and ptest_scaffolds is None:
        ptest_scaffolds = compute_intermediate_statistics(
            test_scaffolds,
            n_jobs=n_jobs,
            device=device,
            batch_size=batch_size,
            pool=pool)
    # Parse SMILES once; molecule-based metrics reuse `mols`.
    mols = mapper(pool)(get_mol, gen)
    kwargs = {'n_jobs': pool, 'device': device, 'batch_size': batch_size}
    # FCD gets the raw n_jobs (not the pool object) — note the difference.
    kwargs_fcd = {'n_jobs': n_jobs, 'device': device, 'batch_size': batch_size}
    metrics['FCD/Test'] = FCDMetric(**kwargs_fcd)(gen=gen, pref=ptest['FCD'])
    metrics['SNN/Test'] = SNNMetric(**kwargs)(gen=mols, pref=ptest['SNN'])
    metrics['Frag/Test'] = FragMetric(**kwargs)(gen=mols, pref=ptest['Frag'])
    metrics['Scaf/Test'] = ScafMetric(**kwargs)(gen=mols, pref=ptest['Scaf'])
    if ptest_scaffolds is not None:
        metrics['FCD/TestSF'] = FCDMetric(**kwargs_fcd)(
            gen=gen, pref=ptest_scaffolds['FCD'])
        metrics['SNN/TestSF'] = SNNMetric(**kwargs)(
            gen=mols, pref=ptest_scaffolds['SNN'])
        metrics['Frag/TestSF'] = FragMetric(**kwargs)(
            gen=mols, pref=ptest_scaffolds['Frag'])
        metrics['Scaf/TestSF'] = ScafMetric(**kwargs)(
            gen=mols, pref=ptest_scaffolds['Scaf'])

    metrics['IntDiv'] = internal_diversity(mols, pool, device=device)
    metrics['IntDiv2'] = internal_diversity(mols, pool, device=device, p=2)
    metrics['Filters'] = fraction_passes_filters(mols, pool)

    # Properties
    for name, func in [('logP', logP), ('SA', SA), ('QED', QED),
                       ('weight', weight)]:
        metrics[name] = WassersteinMetric(func, **kwargs)(gen=mols,
                                                          pref=ptest[name])

    # NOTE(review): train is never None here after the `or` above, unless
    # get_dataset('train') itself can return None — confirm.
    if train is not None:
        metrics['Novelty'] = novelty(mols, train, pool)
    enable_rdkit_log()
    if close_pool:
        pool.close()
        pool.join()
    return metrics