Example #1
import os

import tensorflow as tf

# get_arguments() and the model classes below (SRCNN, ESPCN, LDSP, VESPCN,
# VSRnet) are assumed to come from this repository's own modules.
def main():
    args = get_arguments()

    if not os.path.exists(args.output_folder):
        os.mkdir(args.output_folder)

    if (not args.random_weights) and (args.ckpt_path is None):
        print("Path to the checkpoint file was not provided")
        exit(1)

    if args.model == 'srcnn':
        model = SRCNN(args)
    elif args.model == 'espcn':
        model = ESPCN(args)
    elif args.model == 'ldsp':
        model = LDSP(args)
    elif args.model == 'vespcn':
        model = VESPCN(args)
    elif args.model == 'vsrnet':
        model = VSRnet(args)
    else:
        exit(1)

    with tf.Session() as sess:
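        # Build the inference graph from a placeholder input; tf.identity()
        # below pins the output tensor's name to 'y' for later freezing.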
        input_ph = model.get_placeholder()
        predicted = model.load_model(input_ph)

        if args.model == 'vespcn':
            predicted = predicted[2]
        predicted = tf.identity(predicted, name='y')

        if args.random_weights:
            print("Random Weights Loaded.")
            init = tf.global_variables_initializer()
            sess.run(init)
        else:
            print("Checkpoint Weights Loaded.")
            if os.path.isdir(args.ckpt_path):
                args.ckpt_path = tf.train.latest_checkpoint(args.ckpt_path)
            saver = tf.train.Saver()
            saver.restore(sess, args.ckpt_path)

        # Freeze the graph: convert variables to constants and write the
        # result as a standalone .pb file.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['y'])
        tf.train.write_graph(output_graph_def, args.output_folder,
                             args.model + '.pb', as_text=False)
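
The frozen graph can be reloaded later without any of the model code. A minimal sketch, assuming a hypothetical file path and the output tensor name 'y' set via tf.identity() above (the input placeholder's name depends on get_placeholder()):

import tensorflow as tf

# Read the serialized GraphDef produced by tf.train.write_graph(...).
with tf.gfile.GFile('output/espcn.pb', 'rb') as f:  # hypothetical path
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    y = graph.get_tensor_by_name('y:0')  # the frozen output tensor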
Example #2
import os

import tensorflow as tf
from tqdm import tqdm

# get_arguments() and the model classes are assumed to come from this
# repository's own modules.
def main():
    args = get_arguments()

    if args.model == 'srcnn':
        model = SRCNN(args)
    elif args.model == 'espcn':
        model = ESPCN(args)
    elif args.model == 'ldsp':
        model = LDSP(args)
    elif args.model == 'vespcn':
        model = VESPCN(args)
    elif args.model == 'vsrnet':
        model = VSRnet(args)
    else:
        exit(1)

    with tf.Session() as sess:
        data_batch, data_initializer = model.get_data()

        predicted_batch = model.load_model(data_batch)

        loss = model.get_loss(data_batch, predicted_batch)

        global_step = tf.Variable(0, trainable=False)
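        # Optionally decay the learning rate as training progresses.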
        if args.use_lr_decay:
            lr = tf.train.exponential_decay(args.learning_rate,
                                            global_step,
                                            args.lr_decay_epochs *
                                            model.dataset.examples_num,
                                            args.lr_decay_rate,
                                            staircase=args.staircase_lr_decay)

        else:
            lr = args.learning_rate
        if args.optimizer == 'adam':
            optimizer = tf.train.AdamOptimizer(lr)
        elif args.optimizer == 'momentum':
            optimizer = tf.train.MomentumOptimizer(lr, args.momentum)
        elif args.optimizer == 'sgd':
            optimizer = tf.train.GradientDescentOptimizer(lr)
        else:
            # Guard against an unknown optimizer, mirroring the model check.
            exit(1)
        grads_vars = optimizer.compute_gradients(loss)
        grads_vars_final = []
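        # Scale selected gradients by per-variable learning-rate multipliers.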
        for gradient, variable in grads_vars:
            assert gradient is not None, variable.name

            if variable.name in model.lr_multipliers:
                gradient *= model.lr_multipliers[variable.name]
            grads_vars_final.append((gradient, variable))

            # Log min/max/L2/mean statistics for each variable and gradient.
            variable_name = variable.name.replace(':', '_')

            scope = 'TrainLogs/' + variable_name + '/Values/'
            tf.summary.scalar(scope + 'MIN', tf.reduce_min(variable))
            tf.summary.scalar(scope + 'MAX', tf.reduce_max(variable))
            tf.summary.scalar(scope + 'L2', tf.norm(variable))
            tf.summary.scalar(scope + 'AVG', tf.reduce_mean(variable))

            scope = 'TrainLogs/' + variable_name + '/Gradients/'
            tf.summary.scalar(scope + 'MIN', tf.reduce_min(gradient))
            tf.summary.scalar(scope + 'MAX', tf.reduce_max(gradient))
            tf.summary.scalar(scope + 'L2', tf.norm(gradient))
            tf.summary.scalar(scope + 'AVG', tf.reduce_mean(gradient))
        train_op = optimizer.apply_gradients(grads_vars_final,
                                             global_step=global_step)
        tf.summary.scalar('Learning_rate', lr)

        summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.logdir, sess.graph)

        saver = tf.train.Saver()
        last_epoch = 0
        if args.ckpt_path is None:
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
        else:
            if os.path.isdir(args.ckpt_path):
                args.ckpt_path = tf.train.latest_checkpoint(args.ckpt_path)
            # Recover the epoch number from a name like 'model_12.ckpt';
            # use the basename so dots in directory names don't break it.
            last_epoch = int(os.path.basename(
                args.ckpt_path).split('.')[0].split('_')[-1])
            saver.restore(sess, args.ckpt_path)
        sess.run(data_initializer)

        # Ceiling division: one extra step when the last batch is partial.
        num_steps_in_epoch = model.dataset.examples_num // args.batch_size + (
            1 if model.dataset.examples_num % args.batch_size != 0 else 0)
        for epoch in range(args.num_epochs):
            print("Epoch: ", epoch + last_epoch)
            bar = tqdm(range(num_steps_in_epoch),
                       total=num_steps_in_epoch,
                       unit='step',
                       smoothing=1.0)
            for i in bar:
                _, cur_loss, cur_summary = sess.run([train_op, loss, summary])
                bar.set_description('Loss: ' + str(cur_loss))
                bar.refresh()
                if (i + 1) % args.steps_per_log == 0:
                    summary_writer.add_summary(
                        cur_summary,
                        (last_epoch + epoch) * num_steps_in_epoch + i)
            if epoch % args.epochs_per_save == 0:
                saver.save(
                    sess,
                    os.path.join(
                        args.logdir,
                        'model_' + str(last_epoch + epoch + 1) + '.ckpt'))
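
tf.train.exponential_decay follows the documented schedule lr = learning_rate * decay_rate ** (global_step / decay_steps), flooring the exponent when staircase=True. A small sketch with hypothetical values that mirrors it in plain Python:

def decayed_lr(base_lr, global_step, decay_steps, decay_rate, staircase=False):
    # Mirrors tf.train.exponential_decay's documented formula.
    exponent = global_step / decay_steps
    if staircase:
        exponent = global_step // decay_steps  # decay in discrete jumps
    return base_lr * decay_rate ** exponent

# E.g. base_lr=1e-3, decay_rate=0.5, decay_steps=10000: the learning rate
# halves every 10000 training steps.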
Example #3
import os
from time import time

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# get_arguments() and the model classes are assumed to come from this
# repository's own modules.
def main():
    args = get_arguments()

    if args.model == 'srcnn':
        model = SRCNN(args)
    elif args.model == 'espcn':
        model = ESPCN(args)
    elif args.model == 'ldsp':
        model = LDSP(args)
    elif args.model == 'vespcn':
        model = VESPCN(args)
    elif args.model == 'vsrnet':
        model = VSRnet(args)
    else:
        exit(1)

    with tf.Session() as sess:
        data_batch, data_initializer = model.get_data()

        predicted_batch = model.load_model(data_batch)

        metrics = model.calculate_metrics(data_batch, predicted_batch)

        if args.ckpt_path is None:
            print("Path to the checkpoint file was not provided")
            exit(1)

        if os.path.isdir(args.ckpt_path):
            args.ckpt_path = tf.train.latest_checkpoint(args.ckpt_path)
        saver = tf.train.Saver()
        saver.restore(sess, args.ckpt_path)

        summary = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter(args.logdir, sess.graph)

        sess.run(data_initializer)

        # Ceiling division: enough steps for one full pass over the dataset.
        steps = model.dataset.examples_num // args.batch_size + (
            1 if model.dataset.examples_num % args.batch_size > 0 else 0)
        # Recover the epoch number from a name like 'model_12.ckpt'.
        epoch = int(os.path.basename(
            args.ckpt_path).split('.')[0].split('_')[-1])
        logged_iterations = 0
        metrics_results = [[metric[0], np.array([])] for metric in metrics]
        time_res = 0.0
        for i in tqdm(range(steps), total=steps, unit='step'):
            start = time()
            results = sess.run([metric[1] for metric in metrics] + [summary])
            time_res += time() - start
            cur_metrics_results = results[:-1]
            # Per-example vectors concatenate directly; scalar results are
            # wrapped in a list so they stack with the running array.
            for j in range(len(cur_metrics_results)):
                if len(cur_metrics_results[j].shape) == len(
                        metrics_results[j][1].shape):
                    metrics_results[j][1] = np.concatenate(
                        (metrics_results[j][1], cur_metrics_results[j]))
                else:
                    metrics_results[j][1] = np.concatenate(
                        (metrics_results[j][1], [cur_metrics_results[j]]))
            cur_summary_results = results[-1]
            if (i + 1) % args.steps_per_log == 0:
                summary_writer.add_summary(cur_summary_results,
                                           epoch * steps + logged_iterations)
                logged_iterations += 1

        mean_metrics = [(metric[0], np.mean(metric[1]))
                        for metric in metrics_results]
        mean_metrics.append(("Time", time_res / model.dataset.examples_num))
        metric_summaries = []
        for metric in mean_metrics:
            print("Mean " + metric[0] + ': ', metric[1])
            metric_summaries.append(tf.summary.scalar(metric[0], metric[1]))

        metric_summary = tf.summary.merge(metric_summaries)
        metric_summary_res = sess.run(metric_summary)
        summary_writer.add_summary(metric_summary_res, epoch)
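
calculate_metrics() is repository code that is not shown here. As an illustration only, per-example metrics in the same (name, tensor) form the loop above consumes could be built from TensorFlow's image ops; the repository's actual metrics may differ:

import tensorflow as tf

def example_metrics(hr_batch, sr_batch):
    # Hypothetical stand-in for model.calculate_metrics(); assumes images
    # normalized to [0, 1].
    psnr = tf.image.psnr(hr_batch, sr_batch, max_val=1.0)  # shape: [batch]
    ssim = tf.image.ssim(hr_batch, sr_batch, max_val=1.0)  # shape: [batch]
    return [('PSNR', psnr), ('SSIM', ssim)]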
Example #4
import os
from collections import OrderedDict

import numpy as np
import tensorflow as tf

# get_arguments(), the model classes, dump_to_file() and the
# prepare_native_mf_* helpers are assumed repository-local imports.
def main():
    args = get_arguments()

    if not os.path.exists(args.output_folder):
        os.mkdir(args.output_folder)

    if args.ckpt_path is None:
        print("Path to the checkpoint file was not provided")
        exit(1)

    if args.model == 'srcnn':
        model = SRCNN(args)
    elif args.model == 'espcn':
        model = ESPCN(args)
    elif args.model == 'vespcn':
        model = VESPCN(args)
    elif args.model == 'vsrnet':
        model = VSRnet(args)
    else:
        exit(1)

    with tf.Session() as sess:
        input_ph = model.get_placeholder()
        predicted = model.load_model(input_ph)

        if args.model == 'vespcn':
            predicted = predicted[2]
        predicted = tf.identity(predicted, name='y')

        if os.path.isdir(args.ckpt_path):
            args.ckpt_path = tf.train.latest_checkpoint(args.ckpt_path)
        saver = tf.train.Saver()
        saver.restore(sess, args.ckpt_path)

        # Serialize the weights in the project's native .model format.
        with open(os.path.join(args.output_folder, args.model + '.model'),
                  'wb') as native_mf:
            weights = model.get_model_weights(sess)
            if args.model == 'srcnn':
                prepare_native_mf_srcnn(weights, native_mf)
            elif args.model == 'espcn':
                prepare_native_mf_espcn(weights, native_mf, args.scale_factor)
            elif args.model == 'vespcn':
                prepare_native_mf_vespcn(weights, native_mf, args.scale_factor)
            elif args.model == 'vsrnet':
                prepare_native_mf_vsrnet(weights, native_mf)

        # Emit the same weights as C arrays in a header; the AVFILTER_DNN_*
        # include guards suggest FFmpeg's libavfilter as the consumer.
        with open(os.path.join(args.output_folder, 'dnn_' + args.model + '.h'),
                  'w') as header:
            header.write('/**\n')
            header.write(' * @file\n')
            header.write(' * Default cnn weights for x' +
                         str(args.scale_factor) + ' upscaling with ' +
                         args.model + ' model.\n')
            header.write(' */\n\n')

            header.write('#ifndef AVFILTER_DNN_' + args.model.upper() + '_H\n')
            header.write('#define AVFILTER_DNN_' + args.model.upper() + '_H\n')

            variables = tf.trainable_variables()
            var_dict = OrderedDict()
            for variable in variables:
                var_name = variable.name.split(':')[0].replace('/', '_')
                value = variable.eval()
                if 'kernel' in var_name:
                    # Reorder HWIO convolution kernels to OHWI for the header.
                    value = np.transpose(value, axes=(3, 0, 1, 2))
                var_dict[var_name] = value

            for name, value in var_dict.items():
                dump_to_file(header, value, name)

            header.write('#endif\n')

        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['y'])
        tf.train.write_graph(output_graph_def,
                             args.output_folder,
                             args.model + '.pb',
                             as_text=False)
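
dump_to_file() also comes from the repository and is not shown. A hypothetical sketch of a function with the same signature, writing a NumPy array as a static C array, illustrates the kind of header the loop above generates:

import numpy as np

def dump_to_file(header, value, name):
    # Hypothetical illustration only; the repository's dump_to_file may differ.
    flat = value.flatten()
    header.write('static const float ' + name + '[' + str(flat.size) + '] = {\n')
    header.write(', '.join('%.10ff' % v for v in flat))
    header.write('\n};\n\n')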