コード例 #1
0
def main():
    """Train an ``AAE_basic`` model on the unsupervised training set.

    Relies on the module-level ``map_x_size`` / ``map_y_size`` globals for the
    input size. Parses command-line options, overrides the training
    hyper-parameters below, creates a per-run results directory under
    ``./results/`` and starts training.
    """
    # ------------------------------------------------ Training Phase ------------------------------------------------
    # image_files = random.sample(glob.glob('E:\\work\\pedestrian_crop_python_process\\Pedestrain_cropDB\\train\\0\\*.bmp'), 10)
    # image_files = random.sample(glob.glob('data/0.normal/*.bmp'), 10)
    # data_in = data_read(image_files)

    opt = Options().parse()
    opt.iwidth = map_x_size
    opt.iheight = map_y_size

    # Halve the map width until it is <= 4 and store the truncated result as
    # opt.ctinit (presumably the spatial size at the deepest encoder layer --
    # TODO confirm against AAE_basic). True division is kept on purpose so
    # non-power-of-two sizes yield the same value as before.
    ctinit = map_x_size
    while ctinit > 4:
        ctinit = ctinit / 2
    opt.ctinit = int(ctinit)

    opt.batchsize = 64
    opt.epochs = 1000
    opt.mask = 0  # 1: masking for simulation map
    opt.time = datetime.now()

    train_dataloader = load_data('./data/unsupervised/train/')  # path to trainset

    # Reconstructions during training are written here. makedirs also creates
    # the parent './results/' on a fresh checkout (os.mkdir would raise
    # FileNotFoundError) and tolerates the directory already existing.
    # NOTE(review): str(datetime) contains ':' which is not allowed in Windows
    # paths; opt.time is left unchanged in case other code builds the same path.
    result_path = './results/{0}/'.format(opt.time)
    os.makedirs(result_path, exist_ok=True)

    # dataloader = load_data(opt, data_in)
    model = AAE_basic(opt, train_dataloader)
    model.train()
コード例 #2
0
ファイル: kernel.py プロジェクト: kerviasx/ganomaly-waibao
    def startTest(self, params):
        """Launch a detection (test) run for a previously trained model.

        ``params`` must provide ``'signalInfo'`` (an object with an
        ``emit(int, str)`` method), ``'modelName'``, ``'path'`` and
        ``'signal'``. The saved ``lr`` / ``nz`` / ``batchsize`` and the stored
        ``minVal`` / ``maxVal`` / ``proline`` values for that model are read
        from ``self.modelsData``. The started ``MyTest`` instance is kept on
        ``self.modelTest``.
        """
        notifier = params['signalInfo']
        notifier.emit(0, "开始检测...")  # "Starting detection..."
        name = params['modelName']  # e.g. 'cus_mnist_2'
        root = params['path']  # dataset directory on disk

        opt = Options().parse(name)
        opt.isTrain = False
        opt.load_weights = True
        opt.signal = params['signal']
        opt.signalInfo = notifier

        # Reuse the hyper-parameters recorded when this model was trained.
        record = self.modelsData[name]
        saved = record['opt']
        opt.lr, opt.nz, opt.batchsize = saved['lr'], saved['nz'], saved['batchsize']
        opt.dataroot = root

        print(opt)

        thresholds = [record['minVal'], record['maxVal'], record['proline']]
        self.modelTest = MyTest(opt, thresholds)
        self.modelTest.start()
コード例 #3
0
ファイル: kernel.py プロジェクト: kerviasx/ganomaly-waibao
    def startTrain(self, params):
        """Configure and launch a training run from UI-supplied parameters.

        ``params`` must provide ``'signalInfo'`` (an object with an
        ``emit(int, str)`` method), ``'name'``, ``'path'``, ``'signal'``,
        ``'info'`` and the CLI-style hyper-parameter keys ``'-lr'``,
        ``'-batchsize'``, ``'-niter'`` and ``'-nz'``. The started ``MyModel``
        instance is kept on ``self.modelTrain``.
        """
        reporter = params['signalInfo']
        reporter.emit(0, "开始训练...")  # "Starting training..."
        dataset = params['name']  # e.g. 'cus_mnist_2'
        dataroot = params['path']  # dataset directory on disk

        opt = Options().parse(dataset)
        opt.signal = params['signal']
        opt.load_weights = False
        opt.signalInfo = reporter

        # Copy hyper-parameters from their CLI-style keys onto the options.
        hyper_keys = (
            ('lr', '-lr'),
            ('batchsize', '-batchsize'),
            ('niter', '-niter'),
            ('nz', '-nz'),
        )
        for attr, key in hyper_keys:
            setattr(opt, attr, params[key])

        opt.desc = params['info']
        opt.dataroot = dataroot
        # opt.isize = 128  # optional input-size override, currently disabled
        print(opt)

        self.modelTrain = MyModel(opt)
        self.modelTrain.start()
コード例 #4
0
ファイル: run.py プロジェクト: SDJustus/skip-ganomaly
def main():
    """Train, test or run inference with a skip-ganomaly model.

    The mode is chosen from the parsed options:

    * ``opt.phase == "inference"``: batch size forced to 1, ``model.inference()``;
    * otherwise, if ``opt.path_to_weights`` is set: ``model.test()``;
    * otherwise: ``model.train()``, reporting the wall-clock training time.
    """
    opt = Options().parse()
    opt.print_freq = opt.batchsize
    seed(opt.manualseed)
    # torch.seed() would re-seed torch's RNG with a fresh non-deterministic
    # value, silently discarding the seed requested above and breaking
    # reproducibility. torch.initial_seed() only reports the current seed.
    print("Seed:", str(torch.initial_seed()))

    is_inference = opt.phase == "inference"
    if is_inference:
        opt.batchsize = 1
    data = load_data(opt)
    model = load_model(opt, data)

    if is_inference:
        model.inference()
    elif opt.path_to_weights:
        model.test()
    else:
        # perf_counter is monotonic, so the duration is immune to clock changes.
        train_start = time.perf_counter()
        model.train()
        train_time = time.perf_counter() - train_start
        print(f'Train time: {train_time} secs')
コード例 #5
0
ファイル: analyze.py プロジェクト: Xiaohui9607/AbnormalBDL
def _find_weights(exp, prefix, epoch):
    """Return sorted checkpoint paths matching ``<exp>/weights/<prefix>*_epoch_<epoch>.pth*``."""
    pattern = os.path.join(exp, 'weights', '%s*_epoch_%d.pth*' % (prefix, epoch))
    return sorted(glob.glob(pattern))


def main():
    """Re-evaluate the saved checkpoints of each experiment in ``exps``.

    For every experiment directory, options are reloaded from ``opt.txt`` and
    the matching model (``model_mxn`` for ``setting == 'mxn'``, otherwise
    ``model_mpairs``) is built. Then, for each epoch in ``range(opt.niter)``
    whose full generator and discriminator checkpoint sets are present, the
    weights are loaded and ``model.compute_epoch(epoch)`` is called.
    """
    path = '/mnt/AbnormalResult/'

    exps = [os.path.join(path, '1_cifar/1_pairs_airplane_2/train')]

    for exp in exps:
        optfile = os.path.join(exp, 'opt.txt')
        opt = Options().parse_from_file(optfile)
        opt.batchsize = 64
        if opt.setting == 'mxn':
            model = model_mxn(opt)
        else:
            model = model_mpairs(opt)

        for epoch in range(opt.niter):
            weight_path = {
                'net_G': _find_weights(exp, 'Net_G', epoch),
                'net_D': _find_weights(exp, 'Net_D', epoch),
            }
            # Skip epochs whose checkpoint set is incomplete: the generator and
            # the discriminator sets must BOTH have the expected file counts.
            if (len(weight_path['net_D']) != opt.n_MC_Disc
                    or len(weight_path['net_G']) != opt.n_MC_Gen):
                continue
            try:
                model.load_weight(weight_path)
            except Exception as err:
                # Best-effort sweep: report the unloadable epoch and move on,
                # without swallowing KeyboardInterrupt/SystemExit.
                print("Skipping {}_{}: failed to load weights ({})".format(opt.name, epoch, err))
                continue
            print("{}_{}".format(opt.name, epoch))
            model.compute_epoch(epoch)
コード例 #6
0
                                             batch_size=64,
                                             shuffle=True,
                                             drop_last=True,
                                             num_workers=8)

    return dataloader


# Size of the simulation map fed to the network, and its channel/layer count.
map_x_size = 64
map_y_size = 64
map_layer_num = 3

# Halve the map width until it is <= 4. True division means the loop variable
# becomes a float; the truncated value is stored as opt.ctinit below.
ctinit = map_x_size
while ctinit > 4:
    ctinit /= 2

# Parse options, then force single-sample evaluation with ngpu = 0
# (presumably CPU-only -- confirm gpu_ids = -1 is what Options expects).
opt = Options().parse()
opt.iwidth, opt.iheight = map_x_size, map_y_size
opt.batchsize = 1
opt.ngpu = 0
opt.gpu_ids = -1
opt.ctinit = int(ctinit)

# opt.mask = 1  # 1: masking for simulation map (disabled)

# Module-level run state; names suggest a saved-model flag and the latest
# discriminator/generator losses (presumably updated elsewhere -- confirm).
model_saved = False
d_loss = g_loss = None