Example #1
def test_run_random_walk_smoke_test():
    FLAGS.unparse_flags()
    FLAGS(["argv0"])
    with capture_output() as out:
        with compiler_gym.make("llvm-autophase-ic-v0") as env:
            env.benchmark = "cbench-v1/crc32"
            run_random_walk(env=env, step_count=5)

    print(out.stdout)
    # Note the ".*" before and after the step count to ignore the shell
    # formatting.
    assert re.search(r"Completed .*5.* steps in ", out.stdout)
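
A note on the reset idiom above: absl flags are process-global, so a test has to unparse any previously parsed values and re-parse a minimal argv before running. A minimal sketch of just that pattern, assuming nothing beyond absl itself:

from absl import flags

FLAGS = flags.FLAGS

def reset_flags(argv=("argv0",)):
    # Drop any values left over from a previous parse...
    FLAGS.unparse_flags()
    # ...then re-parse; absl expects the program name in argv[0].
    FLAGS(list(argv))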
Example #2
def main():
    # Common
    flags.DEFINE_string("map_name", "CollectMineralShards", "Name of the map")
    flags.DEFINE_integer("screen_size", 32, "Feature screen size")
    flags.DEFINE_integer("minimap_size", 32, "Feature minimap size")
    flags.DEFINE_bool("visualize", False, "Show python visualisation")
    flags.DEFINE_integer(
        "save_replay_episodes", 500,
        "How often to save replays, in episodes. 0 to disable saving replays.")
    flags.DEFINE_string("replay_dir", os.path.abspath("Replays"),
                        "Directory to save replays.")

    # Environment
    flags.DEFINE_string("env", "movement.MovementEnv",
                        "Which environment to use.")

    # Algo-specific
    flags.DEFINE_integer("envs", 2,
                         "Number of sc2 environments to run in parallel")
    flags.DEFINE_float("max_timesteps", 40, "Max timesteps, in millions")

    # Algo hyperparameters
    flags.DEFINE_string("policy", "fullyconv", "The policy function to use")
    flags.DEFINE_string(
        "lrschedule", "constant",
        "Linear or constant, learning rate schedule for baselines a2c")
    flags.DEFINE_float("learning_rate", 3e-4, "Learning rate")
    flags.DEFINE_float("value_weight", 1.0, "Value function loss weight")
    flags.DEFINE_float("entropy_weight", 1e-5, "Entropy loss weight")

    FLAGS(sys.argv)
    print(sys.argv)

    train()
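
One caveat with defining flags inside main() as above: absl registers each flag name in a global registry, so calling main() a second time in the same process raises flags.DuplicateFlagError. A minimal guard, sketched here with an illustrative helper name that is not part of the original code:

from absl import flags

def define_string_once(name, default, help_text):
    # FlagValues supports membership tests, so only register new names.
    if name not in flags.FLAGS:
        flags.DEFINE_string(name, default, help_text)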
Example #3
def test_tune_smoke_test(search: str, gcc_bin: str, capsys, tmpdir: Path):
    tmpdir = Path(tmpdir)
    flags = [
        "argv0",
        "--seed=0",
        f"--output_dir={tmpdir}",
        f"--gcc_bin={gcc_bin}",
        "--gcc_benchmark=benchmark://chstone-v0/aes",
        f"--search={search}",
        "--pop_size=3",
        "--gcc_search_budget=6",
    ]
    sys.argv = flags
    FLAGS.unparse_flags()
    FLAGS(flags)

    tune.main([])
    out, _ = capsys.readouterr()
    assert "benchmark://chstone-v0/aes" in out
    assert (tmpdir / "results.csv").is_file()
Example #4
    def __init__(self):

        FLAGS(sys.argv)

        assert FLAGS.data_path is not None, '--data_path must be set'
        assert FLAGS.save_path is not None, '--save_path must be set'

        # Get the file paths
        data_path = FLAGS.data_path
        file_names = os.listdir(data_path)

        self.file_names = [file for file in file_names if file.endswith(FLAGS.extension)]

        self.numOfFile = len(self.file_names)
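
Rather than asserting after the fact, absl can enforce required flags at parse time. A short sketch of the same checks via flags.mark_flag_as_required, assuming data_path and save_path are string flags declared with None defaults:

from absl import app, flags

flags.DEFINE_string('data_path', None, 'Directory containing the input images')
flags.DEFINE_string('save_path', None, 'Directory to write results to')
flags.mark_flag_as_required('data_path')
flags.mark_flag_as_required('save_path')

def main(_argv):
    print(flags.FLAGS.data_path, flags.FLAGS.save_path)

if __name__ == '__main__':
    app.run(main)  # exits with a usage error if either flag is missing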
Example #5
def main():
    flags.DEFINE_string("map_name", "MoveToBeacon", "Name of the map")
    flags.DEFINE_integer("frames", 40, "Number of frames in millions")
    flags.DEFINE_integer("step_mul", 8, "sc2 step multiplier")
    flags.DEFINE_integer("n_envs", 1, "Number of sc2 environments to run in parallel")
    flags.DEFINE_integer("resolution", 32, "sc2 resolution")
    flags.DEFINE_string("lrschedule", "constant",
        "linear or constant, learning rate schedule for baselines a2c")
    flags.DEFINE_float("learning_rate", 3e-4, "learning rate")
    flags.DEFINE_boolean("visualize", False, "show pygame visualisation")
    flags.DEFINE_float("value_weight", 1.0, "value function loss weight")
    flags.DEFINE_float("entropy_weight", 1e-5, "entropy loss weight")

    FLAGS(sys.argv)

    train()
Example #6
def main():
    flags.DEFINE_string("map_name", "FindAndDefeatZerglings", "Which map to use")
    flags.DEFINE_boolean("continue_training", False, "Continue with training?")
    flags.DEFINE_integer("frames", 10, "Number of frames in millions")
    flags.DEFINE_integer("horizon", 40, "Number of steps before cutting the trajectory")
    flags.DEFINE_integer("step_mul", 8, "sc2 frame step size")
    flags.DEFINE_integer("n_envs", 1, "Number of sc2 environments to run in parallel")
    flags.DEFINE_integer("resolution", 32, "sc2 resolution")
    flags.DEFINE_float("learning_rate", 7e-4, "learning rate")
    flags.DEFINE_boolean("visualize", False, "show pygame visualisation")
    flags.DEFINE_float("value_weight", 0.5, "value function loss weight")
    flags.DEFINE_float("entropy_weight", 0.01, "entropy loss weight")
    flags.DEFINE_string("expirement_name", "lings_1", "What shall we call this model?")

    FLAGS(sys.argv)

    train()
Example #7
                       action_spec, observation_spec))
        saver = tf.train.Saver(max_to_keep=5)

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        if load_model:
            print('Loading Model...')
            ckpt = tf.train.get_checkpoint_state(model_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            sess.run(tf.global_variables_initializer())

        # This is where the asynchronous magic happens
        # Start the "work" process for each worker in a separate thread
        worker_threads = []
        for worker in workers:
            # Bind the loop variable explicitly; a bare closure would late-bind
            # and could leave every thread running the last worker.
            worker_work = lambda w=worker: w.work(max_episode_length, gamma,
                                                  sess, coord, saver)
            t = threading.Thread(target=worker_work)
            t.start()
            sleep(0.5)
            worker_threads.append(t)
        coord.join(worker_threads)


if __name__ == '__main__':
    flags.DEFINE_string("map_name", "DefeatRoaches",
                        "Name of the map/minigame")
    FLAGS(sys.argv)
    main()
Example #8
def main(_argv):
    FLAGS.every = [int(s) for s in FLAGS.every]
    FLAGS.balance = [s.lower() in ('true', 't') for s in FLAGS.balance]

    if FLAGS.num_workers < 0:
        FLAGS.num_workers = multiprocessing.cpu_count()

    ctx = ([mx.gpu(i) for i in range(FLAGS.num_gpus)]
           if FLAGS.num_gpus > 0 else [mx.cpu()])

    key_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
    print('\n'.join(f.serialize() for f in key_flags))

    # Data augmentation is done in the dataset, since it may need to be
    # applied image-wise when window > 1
    transform_test = None
    if FLAGS.feats_model is None:
        transform_test = transforms.Compose([
            transforms.Resize(FLAGS.data_shape + 32),
            transforms.CenterCrop(FLAGS.data_shape),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        if bool(FLAGS.flow):
            transform_test = transforms.Compose([
                transforms.Resize(FLAGS.data_shape + 32),
                transforms.CenterCrop(FLAGS.data_shape),
                TwoStreamNormalize()
            ])

    test_set = TennisSet(split=FLAGS.split,
                         transform=transform_test,
                         every=FLAGS.every[2],
                         padding=FLAGS.padding,
                         stride=FLAGS.stride,
                         window=FLAGS.window,
                         model_id=FLAGS.model_id,
                         split_id=FLAGS.split_id,
                         balance=False,
                         flow=bool(FLAGS.flow),
                         feats_model=FLAGS.feats_model,
                         save_feats=FLAGS.save_feats)

    print(test_set)

    test_data = gluon.data.DataLoader(test_set,
                                      batch_size=FLAGS.batch_size,
                                      shuffle=False,
                                      num_workers=FLAGS.num_workers)

    # Define Model
    model = None
    if FLAGS.feats_model is None:
        if FLAGS.backbone == 'rdnet':
            backbone_net = get_r21d(num_layers=34,
                                    n_classes=400,
                                    t=8,
                                    pretrained=True).features
        else:
            if FLAGS.flow == 'sixc':
                backbone_net = get_model(
                    FLAGS.backbone, pretrained=False
                ).features  # 6-channel input, so we don't want pretrained weights
            else:
                backbone_net = get_model(FLAGS.backbone,
                                         pretrained=True).features

        if FLAGS.flow in ['twos', 'only']:
            if FLAGS.flow == 'only':
                backbone_net = None
            flow_net = get_model(
                FLAGS.backbone, pretrained=True
            ).features  # TODO: the original experiment did not use a pretrained flow net
            model = TwoStreamModel(backbone_net, flow_net,
                                   len(test_set.classes))
        elif FLAGS.backbone == 'rdnet':
            model = FrameModel(backbone_net, len(test_set.classes), swap=True)
        else:
            model = FrameModel(backbone_net, len(test_set.classes))
    elif FLAGS.temp_pool in ['max', 'mean']:
        backbone_net = get_model(FLAGS.backbone, pretrained=True).features
        model = FrameModel(backbone_net, len(test_set.classes))
    if FLAGS.window > 1:  # Time Distributed RNN

        if FLAGS.backbone_from_id and model is not None:
            if os.path.exists(
                    os.path.join('models', 'vision', 'experiments',
                                 FLAGS.backbone_from_id)):
                files = os.listdir(
                    os.path.join('models', 'vision', 'experiments',
                                 FLAGS.backbone_from_id))
                files = [f for f in files if f[-7:] == '.params']
                if len(files) > 0:
                    files = sorted(files,
                                   reverse=True)  # put latest model first
                    model_name = files[0]
                    model.load_parameters(
                        os.path.join('models', 'vision', 'experiments',
                                     FLAGS.backbone_from_id, model_name))
                    print('Loaded backbone params: {}'.format(
                        os.path.join('models', 'vision', 'experiments',
                                     FLAGS.backbone_from_id, model_name)))

        if FLAGS.freeze_backbone and model is not None:
            for param in model.collect_params().values():
                param.grad_req = 'null'

        if FLAGS.temp_pool in ['gru', 'lstm']:
            model = CNNRNN(model,
                           num_classes=len(test_set.classes),
                           type=FLAGS.temp_pool,
                           hidden_size=128)
        elif FLAGS.temp_pool in ['mean', 'max']:
            pass
        else:
            assert FLAGS.backbone == 'rdnet'  # ensure 3d net
            assert FLAGS.window in [8, 32]

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        model.initialize()

    num_channels = 3
    if bool(FLAGS.flow):
        num_channels = 6
    if FLAGS.feats_model is None:
        if FLAGS.window == 1:
            print(
                model.summary(
                    mx.nd.ndarray.ones(shape=(1, num_channels,
                                              FLAGS.data_shape,
                                              FLAGS.data_shape))))
        else:
            print(
                model.summary(
                    mx.nd.ndarray.ones(shape=(1, FLAGS.window, num_channels,
                                              FLAGS.data_shape,
                                              FLAGS.data_shape))))
    else:
        if FLAGS.window == 1:
            print(model.summary(mx.nd.ndarray.ones(shape=(1, 4096))))
        elif FLAGS.temp_pool not in ['max', 'mean']:
            print(
                model.summary(mx.nd.ndarray.ones(shape=(1, FLAGS.window,
                                                        4096))))

    model.collect_params().reset_ctx(ctx)
    model.hybridize()

    if FLAGS.save_feats:
        best_score = -1
        best_epoch = -1
        with open(
                os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                             'scores.txt'), 'r') as f:
            lines = f.readlines()
            lines = [line.rstrip().split() for line in lines]
            for ep, sc in lines:
                if float(sc) > best_score:
                    best_epoch = int(ep)
                    best_score = float(sc)

        print('Testing best model from Epoch %d with score of %f' %
              (best_epoch, best_score))
        model.load_parameters(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                         "{:04d}.params".format(best_epoch)))
        print('Loaded model params: {}'.format(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                         "{:04d}.params".format(best_epoch))))

        for data, sett in zip([test_data], [test_set]):
            save_features(model, data, sett, ctx)
        return

    if os.path.exists(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id)):
        files = os.listdir(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id))
        files = [f for f in files if f[-7:] == '.params']
        if len(files) > 0:
            files = sorted(files, reverse=True)  # put latest model first
            model_name = files[0]
            model.load_parameters(os.path.join('models', 'vision',
                                               'experiments', FLAGS.model_id,
                                               model_name),
                                  ctx=ctx)
            print('Loaded model params: {}'.format(
                os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                             model_name)))

    # Setup Metrics
    test_metrics = [
        Accuracy(label_names=test_set.classes),
        mx.metric.TopKAccuracy(5, label_names=test_set.classes),
        Accuracy(name='accuracy_no',
                 label_names=test_set.classes[1:],
                 ignore_labels=[0]),
        Accuracy(name='accuracy_o',
                 label_names=test_set.classes[0],
                 ignore_labels=list(range(1, len(test_set.classes)))),
        PRF1(label_names=test_set.classes)
    ]

    # model training complete, test it
    if FLAGS.temp_pool not in ['max', 'mean']:
        mod_path = os.path.join('models', 'vision', 'experiments',
                                FLAGS.model_id)
    else:
        mod_path = os.path.join('models', 'vision', 'experiments',
                                FLAGS.feats_model)
    best_score = -1
    best_epoch = -1
    with open(os.path.join(mod_path, 'scores.txt'), 'r') as f:
        lines = f.readlines()
        lines = [line.rstrip().split() for line in lines]
        for ep, sc in lines:
            if float(sc) > best_score:
                best_epoch = int(ep)
                best_score = float(sc)

    print('Testing best model from Epoch %d with score of %f' %
          (best_epoch, best_score))
    model.load_parameters(
        os.path.join(mod_path, "{:04d}.params".format(best_epoch)))
    print('Loaded model params: {}'.format(
        os.path.join(mod_path, "{:04d}.params".format(best_epoch))))

    if FLAGS.temp_pool in ['max', 'mean']:
        # temporal pooling requires a pretrained net to have been loaded
        assert FLAGS.backbone_from_id or FLAGS.feats_model
        model = TemporalPooling(model,
                                pool=FLAGS.temp_pool,
                                num_classes=0,
                                feats=FLAGS.feats_model is not None)

    tic = time.time()

    results, gts = evaluate_model(model, test_data, test_set, test_metrics,
                                  ctx)

    str_ = 'Test set:'
    for i in range(len(test_set.classes)):
        str_ += '\n'
        for j in range(len(test_set.classes)):
            str_ += str(test_metrics[4].mat[i, j]) + '\t'
    print(str_)

    str_ = '[Finished] '
    for metric in test_metrics:
        result = metric.get()
        if not isinstance(result, list):
            result = [result]
        for res in result:
            str_ += ', Test_{}={:.3f}'.format(res[0], res[1])
        metric.reset()

    str_ += '  # Samples: {}, Time Taken: {:.1f}'.format(
        len(test_set),
        time.time() - tic)
    print(str_)

    if FLAGS.vis:
        visualise_events(test_set,
                         results,
                         video_path=os.path.join('models', 'vision',
                                                 'experiments', FLAGS.model_id,
                                                 'results.mp4'),
                         gt=gts)
Example #9
def detector(images_coming, threshold, prop):
    FLAGS(sys.argv)
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    input_size = prop['size']

    # load model
    saved_model_loaded = tf.saved_model.load(prop['weights'],
                                             tags=[tag_constants.SERVING])

    # loop through images in list and run Yolov4 model on each
    for count, org_image in enumerate(images_coming, 1):
        original_image = cv2.cvtColor(org_image, 1)
        image_data = cv2.resize(original_image, (input_size, input_size))
        image_data = image_data / 255.

        images_data = np.asarray([image_data]).astype(np.float32)  # batch of one
        infer = saved_model_loaded.signatures['serving_default']
        batch_data = tf.constant(images_data)
        pred_bbox = infer(batch_data)
        for key, value in pred_bbox.items():
            boxes = value[:, :, 0:4]
            pred_conf = value[:, :, 4:]

        # run non max suppression on detections
        boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf,
                (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=50,
            max_total_size=50,
            iou_threshold=prop['iou'],
            score_threshold=threshold)

        # format bounding boxes from normalized ymin, xmin, ymax, xmax ---> xmin, ymin, xmax, ymax
        original_h, original_w, _ = original_image.shape
        bboxes = utils.format_boxes(boxes.numpy()[0], original_h, original_w)

        # hold all detection data in one variable
        pred_bbox = [
            bboxes,
            scores.numpy()[0],
            classes.numpy()[0],
            valid_detections.numpy()[0]
        ]

        # read in all class names from config
        class_names = utils.read_class_names(cfg.YOLO.CLASSES)

        # by default allow all classes in .names file
        allowed_classes = list(class_names.values())

        # if count flag is enabled, perform counting of objects
        counted_classes = None
        if prop['count']:
            # count objects found
            counted_classes = count_objects(pred_bbox,
                                            by_class=True,
                                            allowed_classes=allowed_classes)
            # loop through dict and print
            for key, value in counted_classes.items():
                print("Number of {}s: {}".format(key, value))
            image = utils.draw_bbox(original_image,
                                    pred_bbox,
                                    prop['info'],
                                    counted_classes,
                                    allowed_classes=allowed_classes)
        else:
            image = utils.draw_bbox(original_image,
                                    pred_bbox,
                                    prop['info'],
                                    allowed_classes=allowed_classes)

        image = Image.fromarray(image.astype(np.uint8))
        image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
        # NOTE: returns inside the loop, so only the first image is processed
        return image, counted_classes, pred_bbox
Example #10
from PIL import Image


ROOT = "./"
filename_darknet_weights = ROOT + 'yolov3.weights'
filename_classes = ROOT + 'coco.names'
filename_converted_weights = ROOT + 'yolov3.tf'

# Flags are used to define several options for YOLO.
flags.DEFINE_string('classes', filename_classes, 'path to classes file')
flags.DEFINE_string('weights', filename_converted_weights, 'path to weights file')
flags.DEFINE_boolean('tiny', False, 'yolov3 or yolov3-tiny')
flags.DEFINE_integer('size', 416, 'resize images to')
flags.DEFINE_string('tfrecord', None, 'tfrecord instead of image')
flags.DEFINE_integer('num_classes', 80, 'number of classes in the model')
FLAGS([sys.argv[0]])

yolo = YoloV3(classes=FLAGS.num_classes)

# Load weights and classes
yolo.load_weights(FLAGS.weights).expect_partial()
print('weights loaded')
class_names = [c.strip() for c in open(FLAGS.classes).readlines()]
print('classes loaded')


def predict(img):
    arr = tf.expand_dims(img, 0)
    arr = transform_images(arr, FLAGS.size)
    FLAGS.yolo_score_threshold = 0.5
    boxes, scores, classes, nums = yolo(arr)
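
Two idioms above are easy to miss: FLAGS([sys.argv[0]]) parses an empty argument list, so every flag keeps its declared default, and FLAGS.yolo_score_threshold = 0.5 shows that a defined flag can still be overridden by plain attribute assignment after parsing. A minimal sketch of both, independent of YOLO:

import sys
from absl import flags

flags.DEFINE_integer('size', 416, 'resize images to')
FLAGS = flags.FLAGS

FLAGS([sys.argv[0]])  # no arguments parsed: all flags take their defaults
assert FLAGS.size == 416

FLAGS.size = 608      # programmatic override after parsing
assert FLAGS.size == 608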
Example #11
def main(_argv):

    if FLAGS.num_gpus > 0:  # only supports 1 GPU
        ctx = mx.gpu()
    else:
        ctx = mx.cpu()

    key_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
    print('\n'.join(f.serialize() for f in key_flags))

    # are we using features or do we include the CNN?
    if FLAGS.feats_model is None:
        backbone_net = get_model(FLAGS.backbone, pretrained=True, ctx=ctx).features
        cnn_model = FrameModel(backbone_net, 11)  # hardcoded the number of classes
        if FLAGS.backbone_from_id:
            if os.path.exists(os.path.join('models', 'vision', 'experiments', FLAGS.backbone_from_id)):
                files = os.listdir(os.path.join('models', 'vision', 'experiments', FLAGS.backbone_from_id))
                files = [f for f in files if f[-7:] == '.params']
                if len(files) > 0:
                    files = sorted(files, reverse=True)  # put latest model first
                    model_name = files[0]
                    cnn_model.load_parameters(os.path.join('models', 'vision', 'experiments', FLAGS.backbone_from_id, model_name), ctx=ctx)
                    print('Loaded backbone params: {}'.format(os.path.join('models', 'vision', 'experiments', FLAGS.backbone_from_id, model_name)))
            else:
                raise FileNotFoundError('{}'.format(os.path.join('models', 'vision', 'experiments', FLAGS.backbone_from_id)))

        if FLAGS.freeze_backbone:
            for param in cnn_model.collect_params().values():
                param.grad_req = 'null'

        cnn_model = TimeDistributed(cnn_model.backbone)

        src_embed = cnn_model

        transform_test = transforms.Compose([
            transforms.Resize(FLAGS.data_shape + 32),
            transforms.CenterCrop(FLAGS.data_shape),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])
        # transform_train was undefined on this branch but is used below;
        # this script only evaluates, so reuse the test-time transforms.
        transform_train = transform_test

    else:
        from mxnet.gluon import nn  # local import; avoids using an Embedding layer on the src side
        src_embed = nn.HybridSequential(prefix='src_embed_')
        with src_embed.name_scope():
            src_embed.add(nn.Dropout(rate=0.0))

        transform_train = None
        transform_test = None

    # setup the data
    data_train = TennisSet(split='train', transform=transform_train, captions=True, max_cap_len=FLAGS.tgt_max_len,
                           every=FLAGS.every, feats_model=FLAGS.feats_model)
    data_val = TennisSet(split='val', transform=transform_test, captions=True, vocab=data_train.vocab,
                         every=FLAGS.every, inference=True, feats_model=FLAGS.feats_model)
    data_test = TennisSet(split='test', transform=transform_test, captions=True, vocab=data_train.vocab,
                          every=FLAGS.every, inference=True, feats_model=FLAGS.feats_model)

    test_tgt_sentences = data_test.get_captions(split=True)
    write_sentences(test_tgt_sentences, os.path.join('models', 'captioning', 'experiments', FLAGS.model_id, 'test_gt.txt'))

    # load embeddings for tgt_embed
    if FLAGS.emb_file:
        word_embs = nlp.embedding.TokenEmbedding.from_file(file_path=os.path.join('data', FLAGS.emb_file))
        data_test.vocab.set_embedding(word_embs)

        input_dim, output_dim = data_test.vocab.embedding.idx_to_vec.shape
        tgt_embed = gluon.nn.Embedding(input_dim, output_dim)
        tgt_embed.initialize(ctx=ctx)
        tgt_embed.weight.set_data(data_test.vocab.embedding.idx_to_vec)
    else:
        tgt_embed = None

    # setup the model
    encoder, decoder = get_gnmt_encoder_decoder(cell_type=FLAGS.cell_type,
                                                hidden_size=FLAGS.num_hidden,
                                                dropout=FLAGS.dropout,
                                                num_layers=FLAGS.num_layers,
                                                num_bi_layers=FLAGS.num_bi_layers)
    model = NMTModel(src_vocab=None, tgt_vocab=data_test.vocab, encoder=encoder, decoder=decoder,
                     embed_size=FLAGS.emb_size, prefix='gnmt_', src_embed=src_embed, tgt_embed=tgt_embed)

    model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
    static_alloc = True
    model.hybridize(static_alloc=static_alloc)
    print(model)

    if os.path.exists(os.path.join('models', 'captioning', 'experiments', FLAGS.model_id)):
        files = os.listdir(os.path.join('models', 'captioning', 'experiments', FLAGS.model_id))
        files = [f for f in files if f[-7:] == '.params']
        if len(files) > 0:
            files = sorted(files, reverse=True)  # put latest model first
            model_name = files[0]
            if model_name == 'valid_best.params':
                model_name = files[1]
            model.load_parameters(os.path.join('models', 'captioning', 'experiments', FLAGS.model_id, model_name), ctx=ctx)
            print('Loaded model params: {}'.format(os.path.join('models', 'captioning', 'experiments', FLAGS.model_id, model_name)))

    # setup the beam search
    translator = BeamSearchTranslator(model=model, beam_size=FLAGS.beam_size,
                                      scorer=nlp.model.BeamSearchScorer(alpha=FLAGS.lp_alpha, K=FLAGS.lp_k),
                                      max_length=FLAGS.tgt_max_len + 100)
    print('Use beam_size={}, alpha={}, K={}'.format(FLAGS.beam_size, FLAGS.lp_alpha, FLAGS.lp_k))

    # run the training
    train_data_loader, val_data_loader, test_data_loader = get_dataloaders(data_train, data_val, data_test)

    # load and evaluate the best model
    if os.path.exists(os.path.join('models', 'captioning', 'experiments', FLAGS.model_id, 'valid_best.params')):
        model.load_parameters(os.path.join('models', 'captioning', 'experiments', FLAGS.model_id, 'valid_best.params'))

    preds_path = os.path.join('models', 'captioning', 'experiments', FLAGS.model_id, 'best_test_out.txt')
    if not os.path.exists(preds_path):
        _, test_translation_out = evaluate(test_data_loader, model, translator, data_train, ctx)
    else:
        test_translation_out = read_sentences(preds_path)

    str_ = ''
    nlgeval = NLGEval()
    metrics_dict = nlgeval.compute_metrics([[' '.join(sent) for sent in test_tgt_sentences]],
                                           [' '.join(sent) for sent in test_translation_out])

    for k, v in metrics_dict.items():
        str_ += ', test ' + k + '={:.4f}'.format(float(v))
    print(str_)

    write_sentences(test_translation_out, preds_path)
Example #12
    def scanner(self):
        FLAGS(sys.argv)
        self.physical_devices = tf.config.experimental.list_physical_devices(
            'GPU')
        if len(self.physical_devices) > 0:
            tf.config.experimental.set_memory_growth(self.physical_devices[0],
                                                     True)

        if FLAGS.tiny:
            self.yolo = YoloV3Tiny(classes=FLAGS.num_classes)
        else:
            self.yolo = YoloV3(classes=FLAGS.num_classes)

        self.yolo.load_weights(FLAGS.weights)
        logging.info('weights loaded')

        self.class_names = [c.strip() for c in open(FLAGS.classes).readlines()]
        logging.info('classes loaded')

        times = []

        try:
            self.vid = cv2.VideoCapture(0)
        except Exception:
            self.vid = cv2.VideoCapture(FLAGS.video)

        self.out = None

        if FLAGS.output:
            # by default VideoCapture returns float instead of int
            self.width = int(self.vid.get(cv2.CAP_PROP_FRAME_WIDTH))
            self.height = int(self.vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
            self.fps = int(self.vid.get(cv2.CAP_PROP_FPS))
            self.codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
            self.out = cv2.VideoWriter(FLAGS.output, self.codec, self.fps,
                                       (self.width, self.height))
        self.fps = 0.0
        self.count = 0

        running = True

        while running:
            _, self.img = self.vid.read()

            if self.img is None:
                logging.warning("Empty Frame")
                time.sleep(0.1)
                self.count += 1
                if self.count < 3:
                    continue
                else:
                    break

            self.img_in = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)
            self.img_in = tf.expand_dims(self.img_in, 0)
            self.img_in = transform_images(self.img_in, FLAGS.size)

            self.t1 = time.time()
            self.boxes, self.scores, self.classes, self.nums = self.yolo.predict(
                self.img_in)

            self.fps = (self.fps + (1. / (time.time() - self.t1))) / 2

            self.img, self.pname = draw_outputs(
                self.img, (self.boxes, self.scores, self.classes, self.nums),
                self.class_names)
            pname = self.pname
            print('in main function:', self.pname)

            self.img = cv2.putText(self.img, "FPS: {:.2f}".format(self.fps),
                                   (0, 30), cv2.FONT_HERSHEY_COMPLEX_SMALL, 1,
                                   (0, 0, 255), 2)

            # draw_outputs(img, outputs, class_names)
            if FLAGS.output:
                self.out.write(self.img)

            cv2.namedWindow('Product Scanner')
            cv2.imshow('Product Scanner', self.img)

            if cv2.waitKey(100) & 0xFF == ord('e'):
                self.dbdata()
                print('destroying scanner window')
                cv2.destroyWindow('Product Scanner')
                running = False
Example #13
def main():
    # Common
    flags.DEFINE_string("map_name", "CollectMineralShards", "Name of the map")
    flags.DEFINE_integer("screen_size", 84, "Feature screen size")
    flags.DEFINE_integer("minimap_size", 64, "Feature minimap size")
    flags.DEFINE_bool("visualize", False, "Show python visualisation")
    flags.DEFINE_integer(
        "save_replay_episodes", 500,
        "How often to save replays, in episodes. 0 to disable saving replays.")
    flags.DEFINE_string(
        "replay_dir", os.path.abspath("Replays"),
        "Directory to save replays, relative to the current working directory."
    )

    # Environment
    flags.DEFINE_string("env", "movement.MovementEnv",
                        "Which environment to use.")

    # Algo-specific settings
    flags.DEFINE_integer(
        "print_freq", 10,
        "How often training progress is printed, in episodes")  # 100
    flags.DEFINE_integer(
        "checkpoint_freq", 10000,
        "How often to checkpoint the model (in temporary directory), in steps"
    )  # 10000
    flags.DEFINE_integer("save_model_freq", 250000,
                         "How often to save the model, in steps")
    flags.DEFINE_integer(
        "num_stack_frames", 0,
        "Number of frames to stack together (memory optimisation). Set 0 to disable stacking."
    )

    # Algo hyperparameters
    flags.DEFINE_float("learning_rate", 1e-5,
                       "Learning rate for adam optimizer")  # 5e-4
    flags.DEFINE_integer("max_timesteps", 2000000, "Max timesteps")  # 100000
    flags.DEFINE_integer("buffer_size", 100000,
                         "Size of replay buffer")  # 50000
    flags.DEFINE_float(
        "exploration_fraction", 0.5,
        "Fraction of max_timesteps over which exploration rate is annealed"
    )  # 0.1
    flags.DEFINE_float("exploration_final_eps", 0.01,
                       "Final value of random action probability")  # 0.02
    flags.DEFINE_integer("train_freq", 4,
                         "How often the model is updated, in steps")  # 1
    flags.DEFINE_integer("learning_starts", 100000,
                         "How many steps before learning starts")  # 1000
    flags.DEFINE_float("gamma", 0.99, "Discount factor")  # 1.0
    flags.DEFINE_integer("target_network_update_freq", 500,
                         "How often the target network is updated")  # 500
    flags.DEFINE_bool("prioritized_replay", True,
                      "Whether prioritized replay is used")  # True

    FLAGS(sys.argv)
    print(sys.argv)
    global save_model_freq  # Make this global since it's checked every timestep
    save_model_freq = FLAGS.save_model_freq

    train()
Example #14
def main(_argv):

    os.makedirs(os.path.join('models', 'captioning', 'experiments',
                             FLAGS.model_id),
                exist_ok=True)

    if FLAGS.num_gpus > 0:  # only supports 1 GPU
        ctx = mx.gpu()
    else:
        ctx = mx.cpu()

    # Set up logging
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = os.path.join('models', 'captioning', 'experiments',
                                 FLAGS.model_id, 'log.txt')
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)

    key_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
    logging.info('\n'.join(f.serialize() for f in key_flags))

    # set up tensorboard summary writer
    tb_sw = SummaryWriter(log_dir=os.path.join(log_dir, 'tb'),
                          comment=FLAGS.model_id)

    # are we using features or do we include the CNN?
    if FLAGS.feats_model is None:
        backbone_net = get_model(FLAGS.backbone, pretrained=True,
                                 ctx=ctx).features
        cnn_model = FrameModel(backbone_net,
                               11)  # hardcoded the number of classes
        if FLAGS.backbone_from_id:
            if os.path.exists(
                    os.path.join('models', 'vision', 'experiments',
                                 FLAGS.backbone_from_id)):
                files = os.listdir(
                    os.path.join('models', 'vision', 'experiments',
                                 FLAGS.backbone_from_id))
                files = [f for f in files if f[-7:] == '.params']
                if len(files) > 0:
                    files = sorted(files,
                                   reverse=True)  # put latest model first
                    model_name = files[0]
                    cnn_model.load_parameters(os.path.join(
                        'models', 'vision', 'experiments',
                        FLAGS.backbone_from_id, model_name),
                                              ctx=ctx)
                    logging.info('Loaded backbone params: {}'.format(
                        os.path.join('models', 'vision', 'experiments',
                                     FLAGS.backbone_from_id, model_name)))
            else:
                raise FileNotFoundError('{}'.format(
                    os.path.join('models', 'vision', 'experiments',
                                 FLAGS.backbone_from_id)))

        if FLAGS.freeze_backbone:
            for param in cnn_model.collect_params().values():
                param.grad_req = 'null'

        cnn_model = TimeDistributed(cnn_model.backbone)

        src_embed = cnn_model

        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(FLAGS.data_shape),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=0.4,
                                         contrast=0.4,
                                         saturation=0.4),
            transforms.RandomLighting(0.1),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

        transform_test = transforms.Compose([
            transforms.Resize(FLAGS.data_shape + 32),
            transforms.CenterCrop(FLAGS.data_shape),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

    else:
        from mxnet.gluon import nn  # local import; avoids using an Embedding layer on the src side
        src_embed = nn.HybridSequential(prefix='src_embed_')
        with src_embed.name_scope():
            src_embed.add(nn.Dropout(rate=0.0))

        transform_train = None
        transform_test = None

    # setup the data
    data_train = TennisSet(split='train',
                           transform=transform_train,
                           captions=True,
                           max_cap_len=FLAGS.tgt_max_len,
                           every=FLAGS.every,
                           feats_model=FLAGS.feats_model)
    data_val = TennisSet(split='val',
                         transform=transform_test,
                         captions=True,
                         vocab=data_train.vocab,
                         every=FLAGS.every,
                         inference=True,
                         feats_model=FLAGS.feats_model)
    data_test = TennisSet(split='test',
                          transform=transform_test,
                          captions=True,
                          vocab=data_train.vocab,
                          every=FLAGS.every,
                          inference=True,
                          feats_model=FLAGS.feats_model)

    val_tgt_sentences = data_val.get_captions(split=True)
    test_tgt_sentences = data_test.get_captions(split=True)
    write_sentences(
        val_tgt_sentences,
        os.path.join('models', 'captioning', 'experiments', FLAGS.model_id,
                     'val_gt.txt'))
    write_sentences(
        test_tgt_sentences,
        os.path.join('models', 'captioning', 'experiments', FLAGS.model_id,
                     'test_gt.txt'))

    # load embeddings for tgt_embed
    if FLAGS.emb_file:
        word_embs = nlp.embedding.TokenEmbedding.from_file(
            file_path=os.path.join('data', FLAGS.emb_file))
        data_train.vocab.set_embedding(word_embs)

        input_dim, output_dim = data_train.vocab.embedding.idx_to_vec.shape
        tgt_embed = gluon.nn.Embedding(input_dim, output_dim)
        tgt_embed.initialize(ctx=ctx)
        tgt_embed.weight.set_data(data_train.vocab.embedding.idx_to_vec)
    else:
        tgt_embed = None

    # setup the model
    encoder, decoder = get_gnmt_encoder_decoder(
        cell_type=FLAGS.cell_type,
        hidden_size=FLAGS.num_hidden,
        dropout=FLAGS.dropout,
        num_layers=FLAGS.num_layers,
        num_bi_layers=FLAGS.num_bi_layers)
    model = NMTModel(src_vocab=None,
                     tgt_vocab=data_train.vocab,
                     encoder=encoder,
                     decoder=decoder,
                     embed_size=FLAGS.emb_size,
                     prefix='gnmt_',
                     src_embed=src_embed,
                     tgt_embed=tgt_embed)

    model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
    static_alloc = True
    model.hybridize(static_alloc=static_alloc)
    logging.info(model)

    start_epoch = 0
    if os.path.exists(
            os.path.join('models', 'captioning', 'experiments',
                         FLAGS.model_id)):
        files = os.listdir(
            os.path.join('models', 'captioning', 'experiments',
                         FLAGS.model_id))
        files = [f for f in files if f[-7:] == '.params']
        if len(files) > 0:
            files = sorted(files, reverse=True)  # put latest model first
            model_name = files[0]
            if model_name == 'valid_best.params':
                model_name = files[1]
            start_epoch = int(model_name.split('.')[0]) + 1
            model.load_parameters(os.path.join('models', 'captioning',
                                               'experiments', FLAGS.model_id,
                                               model_name),
                                  ctx=ctx)
            logging.info('Loaded model params: {}'.format(
                os.path.join('models', 'captioning', 'experiments',
                             FLAGS.model_id, model_name)))

    # setup the beam search
    translator = BeamSearchTranslator(model=model,
                                      beam_size=FLAGS.beam_size,
                                      scorer=nlp.model.BeamSearchScorer(
                                          alpha=FLAGS.lp_alpha, K=FLAGS.lp_k),
                                      max_length=FLAGS.tgt_max_len + 100)
    logging.info('Use beam_size={}, alpha={}, K={}'.format(
        FLAGS.beam_size, FLAGS.lp_alpha, FLAGS.lp_k))

    # setup the loss function
    loss_function = MaskedSoftmaxCELoss()
    loss_function.hybridize(static_alloc=static_alloc)

    # run the training
    train(data_train, data_val, data_test, model, loss_function,
          val_tgt_sentences, test_tgt_sentences, translator, start_epoch, ctx,
          tb_sw)
Example #15
def main(_argv):
    FLAGS.every = [int(s) for s in FLAGS.every]
    FLAGS.balance = [s.lower() in ('true', 't') for s in FLAGS.balance]
    FLAGS.lr_steps = [int(s) for s in FLAGS.lr_steps]

    if FLAGS.num_workers < 0:
        FLAGS.num_workers = multiprocessing.cpu_count()

    ctx = ([mx.gpu(i) for i in range(FLAGS.num_gpus)]
           if FLAGS.num_gpus > 0 else [mx.cpu()])

    # Set up logging
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = os.path.join('models', 'vision', 'experiments',
                                 FLAGS.model_id, 'log.txt')
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)

    key_flags = FLAGS.get_key_flags_for_module(sys.argv[0])
    logging.info('\n'.join(f.serialize() for f in key_flags))

    # set up tensorboard summary writer
    tb_sw = SummaryWriter(log_dir=os.path.join(log_dir, 'tb'),
                          comment=FLAGS.model_id)

    feat_sub_dir = None

    # Data augmentation is done in the dataset, since it may need to be
    # applied image-wise when window > 1
    jitter_param = 0.4
    lighting_param = 0.1
    transform_train = None
    transform_test = None
    balance_train = True
    if FLAGS.feats_model is None:
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(FLAGS.data_shape),
            transforms.RandomFlipLeftRight(),
            transforms.RandomColorJitter(brightness=jitter_param,
                                         contrast=jitter_param,
                                         saturation=jitter_param),
            transforms.RandomLighting(lighting_param),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        transform_test = transforms.Compose([
            transforms.Resize(FLAGS.data_shape + 32),
            transforms.CenterCrop(FLAGS.data_shape),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        if bool(FLAGS.flow):

            transform_test = transforms.Compose([
                transforms.Resize(FLAGS.data_shape + 32),
                transforms.CenterCrop(FLAGS.data_shape),
                TwoStreamNormalize()
            ])

            transform_train = transform_test

    if FLAGS.save_feats:
        balance_train = False
        transform_train = transform_test

    if FLAGS.window > 1:
        transform_train = transform_test

    # Load datasets
    if FLAGS.temp_pool not in ['max', 'mean']:
        train_set = TennisSet(split='train',
                              transform=transform_train,
                              every=FLAGS.every[0],
                              padding=FLAGS.padding,
                              stride=FLAGS.stride,
                              window=FLAGS.window,
                              model_id=FLAGS.model_id,
                              split_id=FLAGS.split_id,
                              balance=balance_train,
                              flow=bool(FLAGS.flow),
                              feats_model=FLAGS.feats_model,
                              save_feats=FLAGS.save_feats)

        logging.info(train_set)

        val_set = TennisSet(split='val',
                            transform=transform_test,
                            every=FLAGS.every[1],
                            padding=FLAGS.padding,
                            stride=FLAGS.stride,
                            window=FLAGS.window,
                            model_id=FLAGS.model_id,
                            split_id=FLAGS.split_id,
                            balance=False,
                            flow=bool(FLAGS.flow),
                            feats_model=FLAGS.feats_model,
                            save_feats=FLAGS.save_feats)

        logging.info(val_set)

    test_set = TennisSet(split='test',
                         transform=transform_test,
                         every=FLAGS.every[2],
                         padding=FLAGS.padding,
                         stride=FLAGS.stride,
                         window=FLAGS.window,
                         model_id=FLAGS.model_id,
                         split_id=FLAGS.split_id,
                         balance=False,
                         flow=bool(FLAGS.flow),
                         feats_model=FLAGS.feats_model,
                         save_feats=FLAGS.save_feats)

    logging.info(test_set)

    # Data Loaders
    if FLAGS.temp_pool not in ['max', 'mean']:
        train_data = gluon.data.DataLoader(train_set,
                                           batch_size=FLAGS.batch_size,
                                           shuffle=True,
                                           num_workers=FLAGS.num_workers)
        val_data = gluon.data.DataLoader(val_set,
                                         batch_size=FLAGS.batch_size,
                                         shuffle=False,
                                         num_workers=FLAGS.num_workers)
    test_data = gluon.data.DataLoader(test_set,
                                      batch_size=FLAGS.batch_size,
                                      shuffle=False,
                                      num_workers=FLAGS.num_workers)

    # Define Model
    model = None
    if FLAGS.feats_model is None:
        if FLAGS.backbone == 'rdnet':
            backbone_net = get_r21d(num_layers=34,
                                    n_classes=400,
                                    t=8,
                                    pretrained=True).features
        else:
            if FLAGS.flow == 'sixc':
                backbone_net = get_model(
                    FLAGS.backbone, pretrained=False
                ).features  # 6-channel input, so we don't want pretrained weights
            else:
                backbone_net = get_model(FLAGS.backbone,
                                         pretrained=True).features

        if FLAGS.flow in ['twos', 'only']:
            if FLAGS.flow == 'only':
                backbone_net = None
            flow_net = get_model(
                FLAGS.backbone, pretrained=True
            ).features  # TODO: the original experiment did not use a pretrained flow net
            model = TwoStreamModel(backbone_net, flow_net,
                                   len(train_set.classes))
        elif FLAGS.backbone == 'rdnet':
            model = FrameModel(backbone_net, len(train_set.classes), swap=True)
        else:
            model = FrameModel(backbone_net, len(train_set.classes))
    elif FLAGS.temp_pool in ['max', 'mean']:
        backbone_net = get_model(FLAGS.backbone, pretrained=True).features
        model = FrameModel(backbone_net, len(test_set.classes))
    if FLAGS.window > 1:  # Time Distributed RNN

        if FLAGS.backbone_from_id and model is not None:
            if os.path.exists(
                    os.path.join('models', 'vision', 'experiments',
                                 FLAGS.backbone_from_id)):
                files = os.listdir(
                    os.path.join('models', 'vision', 'experiments',
                                 FLAGS.backbone_from_id))
                files = [f for f in files if f[-7:] == '.params']
                if len(files) > 0:
                    files = sorted(files,
                                   reverse=True)  # put latest model first
                    model_name = files[0]
                    model.load_parameters(
                        os.path.join('models', 'vision', 'experiments',
                                     FLAGS.backbone_from_id, model_name))
                    logging.info('Loaded backbone params: {}'.format(
                        os.path.join('models', 'vision', 'experiments',
                                     FLAGS.backbone_from_id, model_name)))

        if FLAGS.freeze_backbone and model is not None:
            for param in model.collect_params().values():
                param.grad_req = 'null'

        if FLAGS.temp_pool in ['gru', 'lstm']:
            model = CNNRNN(model,
                           num_classes=len(test_set.classes),
                           type=FLAGS.temp_pool,
                           hidden_size=128)
        elif FLAGS.temp_pool in ['mean', 'max']:
            pass
        else:
            assert FLAGS.backbone == 'rdnet'  # ensure 3d net
            assert FLAGS.window in [8, 32]

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        model.initialize()

    num_channels = 3
    if bool(FLAGS.flow):
        num_channels = 6
    if FLAGS.feats_model is None:
        if FLAGS.window == 1:
            logging.info(
                model.summary(
                    mx.nd.ndarray.ones(shape=(1, num_channels,
                                              FLAGS.data_shape,
                                              FLAGS.data_shape))))
        else:
            logging.info(
                model.summary(
                    mx.nd.ndarray.ones(shape=(1, FLAGS.window, num_channels,
                                              FLAGS.data_shape,
                                              FLAGS.data_shape))))
    else:
        if FLAGS.window == 1:
            logging.info(model.summary(mx.nd.ndarray.ones(shape=(1, 4096))))
        elif FLAGS.temp_pool not in ['max', 'mean']:
            logging.info(
                model.summary(mx.nd.ndarray.ones(shape=(1, FLAGS.window,
                                                        4096))))

    model.collect_params().reset_ctx(ctx)
    model.hybridize()

    if FLAGS.save_feats:
        best_score = -1
        best_epoch = -1
        with open(
                os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                             'scores.txt'), 'r') as f:
            lines = f.readlines()
            lines = [line.rstrip().split() for line in lines]
            for ep, sc in lines:
                if float(sc) > best_score:
                    best_epoch = int(ep)
                    best_score = float(sc)

        logging.info('Testing best model from Epoch %d with score of %f' %
                     (best_epoch, best_score))
        model.load_parameters(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                         "{:04d}.params".format(best_epoch)))
        logging.info('Loaded model params: {}'.format(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                         "{:04d}.params".format(best_epoch))))

        for data, sett in zip([train_data, val_data, test_data],
                              [train_set, val_set, test_set]):
            save_features(model, data, sett, ctx)
        return

    start_epoch = 0
    if os.path.exists(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id)):
        files = os.listdir(
            os.path.join('models', 'vision', 'experiments', FLAGS.model_id))
        files = [f for f in files if f[-7:] == '.params']
        if len(files) > 0:
            files = sorted(files, reverse=True)  # put latest model first
            model_name = files[0]
            start_epoch = int(model_name.split('.')[0]) + 1
            model.load_parameters(os.path.join('models', 'vision',
                                               'experiments', FLAGS.model_id,
                                               model_name),
                                  ctx=ctx)
            logging.info('Loaded model params: {}'.format(
                os.path.join('models', 'vision', 'experiments', FLAGS.model_id,
                             model_name)))

    # Setup the optimiser
    trainer = gluon.Trainer(model.collect_params(), 'sgd', {
        'learning_rate': FLAGS.lr,
        'momentum': FLAGS.momentum,
        'wd': FLAGS.wd
    })

    # Setup Metric/s
    metrics = [
        Accuracy(label_names=test_set.classes),
        mx.metric.TopKAccuracy(5, label_names=test_set.classes),
        Accuracy(name='accuracy_no',
                 label_names=test_set.classes[1:],
                 ignore_labels=[0]),
        Accuracy(name='accuracy_o',
                 label_names=test_set.classes[0],
                 ignore_labels=list(range(1, len(test_set.classes)))),
        PRF1(label_names=test_set.classes)
    ]

    val_metrics = [
        Accuracy(label_names=test_set.classes),
        mx.metric.TopKAccuracy(5, label_names=test_set.classes),
        Accuracy(name='accuracy_no',
                 label_names=test_set.classes[1:],
                 ignore_labels=[0]),
        Accuracy(name='accuracy_o',
                 label_names=test_set.classes[0],
                 ignore_labels=list(range(1, len(test_set.classes)))),
        PRF1(label_names=test_set.classes)
    ]

    test_metrics = [
        Accuracy(label_names=test_set.classes),
        mx.metric.TopKAccuracy(5, label_names=test_set.classes),
        Accuracy(name='accuracy_no',
                 label_names=test_set.classes[1:],
                 ignore_labels=[0]),
        Accuracy(name='accuracy_o',
                 label_names=test_set.classes[0],
                 ignore_labels=list(range(1, len(test_set.classes)))),
        PRF1(label_names=test_set.classes)
    ]

    # Setup Loss/es
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    if FLAGS.temp_pool not in ['max', 'mean']:
        model = train_model(model, train_set, train_data, metrics, val_set,
                            val_data, val_metrics, trainer, loss_fn,
                            start_epoch, ctx, tb_sw)

    # model training complete, test it
    if FLAGS.temp_pool not in ['max', 'mean']:
        mod_path = os.path.join('models', 'vision', 'experiments',
                                FLAGS.model_id)
    else:
        mod_path = os.path.join('models', 'vision', 'experiments',
                                FLAGS.feats_model)
    best_score = -1
    best_epoch = -1
    with open(os.path.join(mod_path, 'scores.txt'), 'r') as f:
        lines = f.readlines()
        lines = [line.rstrip().split() for line in lines]
        for ep, sc in lines:
            if float(sc) > best_score:
                best_epoch = int(ep)
                best_score = float(sc)

    logging.info('Testing best model from Epoch %d with score of %f' %
                 (best_epoch, best_score))
    model.load_parameters(
        os.path.join(mod_path, "{:04d}.params".format(best_epoch)))
    logging.info('Loaded model params: {}'.format(
        os.path.join(mod_path, "{:04d}.params".format(best_epoch))))

    if FLAGS.temp_pool in ['max', 'mean']:
        # temporal pooling requires a pretrained net to have been loaded
        assert FLAGS.backbone_from_id or FLAGS.feats_model
        model = TemporalPooling(model,
                                pool=FLAGS.temp_pool,
                                num_classes=0,
                                feats=FLAGS.feats_model is not None)

    tic = time.time()
    _ = test_model(model,
                   test_data,
                   test_set,
                   test_metrics,
                   ctx,
                   vis=FLAGS.vis)

    if FLAGS.temp_pool not in ['max', 'mean']:
        str_ = 'Train set:'
        for i in range(len(train_set.classes)):
            str_ += '\n'
            for j in range(len(train_set.classes)):
                str_ += str(metrics[4].mat[i, j]) + '\t'
        print(str_)
    str_ = 'Test set:'
    for i in range(len(test_set.classes)):
        str_ += '\n'
        for j in range(len(test_set.classes)):
            str_ += str(test_metrics[4].mat[i, j]) + '\t'
    print(str_)

    str_ = '[Finished] '
    for metric in test_metrics:
        result = metric.get()
        if not isinstance(result, list):
            result = [result]
        for res in result:
            str_ += ', Test_{}={:.3f}'.format(res[0], res[1])
        metric.reset()

    str_ += '  # Samples: {}, Time Taken: {:.1f}'.format(
        len(test_set),
        time.time() - tic)
    logging.info(str_)