コード例 #1
0
ファイル: train.py プロジェクト: atimashov/cnn_live_training
def _set_dropout(model, p):
	"""Replace every exact ``torch.nn.Dropout`` in ``model.features`` and
	``model.classifier`` with a new ``torch.nn.Dropout(p)``.

	Only modules whose type is exactly ``torch.nn.Dropout`` are replaced;
	subclasses and other dropout variants are left untouched.
	"""
	for container in (model.features, model.classifier):
		# Iterate over a snapshot so the container is not mutated mid-iteration.
		for i, layer in enumerate(list(container)):
			if type(layer) is torch.nn.Dropout:
				replacement = torch.nn.Dropout(p)
				# A newly built module starts in training mode; keep the old
				# layer's mode so an eval()-mode model does not get dropout
				# silently re-enabled.
				replacement.train(layer.training)
				container[i] = replacement


def _make_optimizer(model, opt, lr, wd):
	"""Build Adam when ``opt == 'Adam'``, otherwise Nesterov SGD with momentum 0.9."""
	if opt == 'Adam':
		return optim.Adam(model.parameters(), lr = lr, weight_decay = wd)
	return optim.SGD(model.parameters(), lr = lr, weight_decay = wd, momentum = 0.9, nesterov = True)


def get_optimizer(old_params, model, optimizer = None):
	"""Create or refresh the optimizer (and dropout rate) for live training.

	Args:
		old_params: ``(lr, weight_decay, dropout_percent, optimizer_name, flag)``.
			The dropout value is a percentage: layers get ``dropout_percent / 100``.
		model: module exposing ``features`` and ``classifier`` containers
			whose ``torch.nn.Dropout`` layers are rewritten when the rate changes.
		optimizer: ``None`` on the first call, which builds the initial
			optimizer from ``old_params``. Otherwise the current optimizer;
			new values are then read from ``get_params()``.

	Returns:
		``(params_tuple, model, optimizer)``.
		- If ``get_params()`` reports a truthy flag, the freshly read tuple is
		  returned with the model and optimizer unchanged.
		- If any of lr / weight decay / dropout / optimizer name changed, a
		  brand-new optimizer is built, discarding the old optimizer's state.
	"""
	prev_lr, prev_wd, prev_do, prev_opt, flag = old_params
	if optimizer is None:
		# Initial setup: layers are only rewritten when the requested rate is
		# not 50 %. NOTE(review): presumably the model is built with p=0.5 —
		# confirm in init_model.
		if prev_do != 50:
			_set_dropout(model, prev_do / 100)
		optimizer = _make_optimizer(model, prev_opt, prev_lr, prev_wd)
		return (prev_lr, prev_wd, prev_do, prev_opt, flag), model, optimizer

	lr, wd, do, opt, flag = get_params()
	if flag:
		return (lr, wd, do, opt, flag), model, optimizer

	if (lr != prev_lr) or (wd != prev_wd) or (do != prev_do) or (opt != prev_opt):
		if do != prev_do:
			# Apply the NEW rate. The previous code rebuilt the layers with
			# prev_do, so dropout changes never took effect.
			_set_dropout(model, do / 100)
		prev_lr, prev_wd, prev_do, prev_opt = lr, wd, do, opt
		optimizer = _make_optimizer(model, prev_opt, prev_lr, prev_wd)
	return (prev_lr, prev_wd, prev_do, prev_opt, flag), model, optimizer
コード例 #2
0
def get_data(request):
    """Fetch the records selected by ``request`` and return them parsed.

    The request's parameters (from ``helper.get_params``) are passed to
    ``data.get``. A falsy lookup result yields ``{}``; otherwise the result
    is converted with ``helper.parse_data``.
    """
    query = helper.get_params(request)
    records = data.get(query)
    return helper.parse_data(records) if records else {}
コード例 #3
0
ファイル: validate.py プロジェクト: atabekmad/regcheck
def validate_response(response):

    """ Looks for an intersection between response and used_data (list of all generated data in data.py) """

    resp_values = helper.get_params(response).values()
    matches_num = len(list(set(resp_values).intersection(used_data)))
    print "\nused_data = ", used_data
    print "\nresp_values = ", resp_values

    if matches_num == 3:
        print "\nValidation result - Values do match"
    else:
        msg = "Validation failed. Updated data doesn't correspond to the real data. " + "Matched - " + str(matches_num)
        raise AssertionError(msg)
コード例 #4
0
ファイル: train.py プロジェクト: atimashov/cnn_live_training
	# --- CLI options (`parser` is created before this excerpt) ---
	parser.add_argument('--dataset-path', type=str, default='data', help='path to dataset: image_net_10, corrosion_dataset')
	parser.add_argument('--n-print', type=int, default=50, help='how often to print')
	parser.add_argument('--n-epochs', type=int, default=1000, help='number of epochs')
	parser.add_argument('--batch-size', type=int, default=32, help='batch size')
	# Boolean switches are taken as the strings 'True'/'False' and converted below.
	parser.add_argument('--transfer', type=str, default='False', help='transfer/full learning')
	parser.add_argument('--use-gpu', type=str, default='True', help='gpu/cpu')
	inputs = parser.parse_args()
	print(inputs)

	# Any value other than the exact string 'True' counts as False.
	inputs.transfer = True if inputs.transfer == 'True' else False
	USE_GPU = True if inputs.use_gpu == 'True' else False
	dtype = torch.float32  # TODO: find out how it affects speed and accuracy
	# Use CUDA device 0 only when requested AND available; otherwise fall back to CPU.
	device = torch.device('cuda:0' if USE_GPU and torch.cuda.is_available() else 'cpu')

	# Build the model; init_model also returns the class count (not used in this excerpt).
	model, n_classes = init_model(inputs)
	# Block until the first hyper-parameter set arrives: poll get_params(start=True)
	# every 10 units (presumably seconds, via time.sleep) until it returns non-None.
	while True:
		params = get_params(start = True)
		if params is not None: break
		sleep(10)
	# Train and validation datasets are both read from <cwd>/<dataset-path>.
	data_train = MyDataset(root = '{}/{}'.format(os.getcwd(), inputs.dataset_path), train = True)
	data_val = MyDataset(root = '{}/{}'.format(os.getcwd(), inputs.dataset_path), train = False)
	# NOTE(review): the validation loader also shuffles, which evaluation does not need — confirm intent.
	data_loader = {
		'train': DataLoader(data_train, batch_size = inputs.batch_size, shuffle = True, num_workers = 6),
		'val': DataLoader(data_val, batch_size = inputs.batch_size, shuffle = True, num_workers = 6)
	}
	# Start training with the polled parameters; train_my's return value is kept as `hist`.
	hist = train_my(data_loader, model, datetime.now(), epochs = inputs.n_epochs, params = params, device = device, n_print = inputs.n_print)

コード例 #5
0
    data = json.loads(page)
        
    for video in data:
        menu_link = video["Id"]
        menu_name = video["Name"]
        menu_img = video["ThumbnailURL"]
        parameters = {"channel":thisChannel,"action":"playVideo","link":menu_link}
        helper.addDirectoryItem(menu_name, parameters, menu_img, folder=False)

    helper.endOfDirectory()
       
def playVideo(videoPlayer):
    """Resolve a Brightcove video to a stream URL and hand it to the player.

    ``brightcovePlayer.play`` is called with the module-level ``const``,
    ``playerID``, ``publisherID`` and ``playerKey`` values. Element ``[1]`` of
    its result is split at the first ``&`` into a base URL and a play path.
    NOTE(review): the ``<base>&<path>`` shape is inferred from this split —
    confirm against brightcovePlayer.

    The pieces are recombined as ``"<base> playpath=<path>"`` and passed to
    ``helper.setResolvedUrl``.
    """
    stream = brightcovePlayer.play(const, playerID, videoPlayer, publisherID, playerKey)

    # Split once, at the first "&". The previous find()-based slicing
    # mishandled URLs without "&": find() returned -1, so the base lost its
    # last character and the whole URL was repeated as the play path.
    rtmpbase, _, playpath = stream[1].partition("&")
    finalurl = rtmpbase + ' playpath=' + playpath

    helper.setResolvedUrl(finalurl)

params = helper.get_params()
if len(params) == 1:
    mainPage()
else:
    print params['action']
    if params['action'] == "showSubMenu":
        showSubMenu(params['link'])
    if params['action'] == "showVideos":
        showVideos(params['link'])
    if params['action'] == "playVideo":
        playVideo(params['link'])
コード例 #6
0
            playpath = " playpath=" + video.getAttribute("progurl")

            streamUrl = streamUrl[:streamUrl.find(
                "/", 8)] + swfUrl + app + pageUrl + flashVer + playpath
            print streamUrl

            helper.setResolvedUrl(streamUrl)
            #return False
            listItem = xbmcgui.ListItem(streamUrl, path=streamUrl)
            #listItem.setProperty("PlayPath", streamUrl);
            listItem.setProperty('IsPlayable', 'true')
            playlist.add(url=streamUrl, listitem=listItem)

        player = xbmc.Player()
        player.play(playlist, playerItem)
        return False


# Entry point.
# - Exactly one entry from helper.get_params() shows the main page.
# - Otherwise params['action'] selects the handler. Its argument is
#   params['link'], URL-unquoted for every action except "videoPage".
params = helper.get_params()
if len(params) != 1:
    action = params['action']
    if action == "subPageXml":
        subPageXml(urllib.unquote(params['link']))
    elif action == "videoPage":
        videoPage(params['link'])
    elif action == "videoPageXml":
        videoPageXml(urllib.unquote(params['link']))
    elif action == "playVideo":
        playVideo(urllib.unquote(params['link']))
else:
    mainPage()