예제 #1
0
import torch

from mmodel import get_module
from mtrain.watcher import watcher

if __name__ == "__main__":
    # Script entry point: build the hard-coded model via get_module() and
    # run its training routine.
    model_name = "openbb"  # fixed here instead of being read from input()

    # Enable cuDNN's benchmark mode for this run.
    torch.backends.cudnn.benchmark = True

    try:
        params, module = get_module(model_name)

        # Notes are only prepared when the model's parameters ask for a record.
        should_record = params.make_record
        if should_record:
            watcher.prepare_notes(model_name, params.tag)

        module.train_module()
    finally:
        # Runs whether set-up/training succeeded or raised.
        watcher.to_json()

예제 #2
0
import os

# Expose only the CUDA device with index 1 to this process. This is assigned
# before the `import torch` below -- presumably so that it is in place before
# CUDA gets initialized; keep this ordering.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# Port passed to `tb.star_shell_tb()` in the main block
# (presumably the TensorBoard server port -- TODO confirm).
tf_port = 8008

import torch
from main_aid import TBHandler
from mmodel import get_module

if __name__ == "__main__":
    # Script entry point: build the "TFDN" model, attach the writer obtained
    # from TBHandler, start the handler's shell process on `tf_port`, and
    # train. The shell process is always killed on the way out.

    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True

    # Seed the default RNG and the RNGs of all CUDA devices with 0.
    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    model_name = "TFDN"
    tb = TBHandler(model_name)
    param, model = get_module(model_name)

    try:
        model.writer = tb.get_writer()
        tb.star_shell_tb(tf_port)
        model.train_module()
    finally:
        # Cleanup only. An exception raised in the `try` body propagates
        # automatically once this clause finishes, so no explicit `raise` is
        # needed here; a bare `raise` would fail with
        # "RuntimeError: No active exception to reraise" after a successful
        # training run.
        tb.kill_shell_tb()