Example #1
0
def trainBatch():
    """Run one training step on a batch merged from two data iterators.

    Pulls one batch from each of the module-level iterators ``train_iter1``
    and ``train_iter2``, concatenates the images along the batch dimension,
    loads images/labels into the module-level tensors, runs the model, and
    performs a single optimizer step.

    Returns:
        The total loss tensor for this step.

    NOTE(review): relies on many module-level globals (``train_iter1``,
    ``converter``, ``MODEL``, ``criterion``, ``optimizer``, the ``image``/
    ``text*_ori``/``length_ori`` buffers, the running-average meters) —
    confirm they are initialized before the training loop calls this.
    """
    # Draw one batch from each source and merge them into a single batch.
    data1 = train_iter1.next()
    data2 = train_iter2.next()
    cpu_images = torch.cat((data1[0], data2[0]), 0)
    cpu_texts1 = data1[1] + data2[1]   # primary label strings
    cpu_texts2 = data1[3] + data2[3]   # secondary label strings (used only when opt.LR)

    utils.loadData(image, cpu_images)
    t1, l1 = converter.encode(cpu_texts1, scanned=True)
    utils.loadData(text1_ori, t1)
    utils.loadData(length_ori, l1)
    # Lengths of the second text stream are never consumed downstream.
    t2, _ = converter.encode(cpu_texts2, scanned=True)
    utils.loadData(text2_ori, t2)

    # Was `opt.LR is True`; a plain truthiness test is the idiomatic form
    # (identity comparison against True is fragile — PEP 8).
    if opt.LR:
        # Two-headed prediction: each head contributes half of the loss.
        preds1, preds2 = MODEL(image,
                               length_ori,
                               text1_ori,
                               text2_ori,
                               cpu_texts=cpu_texts1)

        cost_pred1 = criterion(preds1, text1_ori) / 2.0
        cost_pred2 = criterion(preds2, text2_ori) / 2.0
        loss_pred_avg1.add(cost_pred1)
        loss_pred_avg2.add(cost_pred2)

        cost = cost_pred1 + cost_pred2
    else:
        # Single-headed prediction: full-weight loss on the first text stream.
        preds1 = MODEL(image,
                       length_ori,
                       text1_ori,
                       None,
                       cpu_texts=cpu_texts1)

        cost_pred1 = criterion(preds1, text1_ori)
        loss_pred_avg1.add(cost_pred1)

        cost = cost_pred1

    loss_avg.add(cost)
    MODEL.zero_grad()
    cost.backward()
    optimizer.step()

    return cost
Example #2
0
File: main.py  Project: gao-ye/CodeShare
def train_batch():
    """Run one training step: fetch a batch, forward, backward, optimize.

    Returns:
        The loss tensor for this step (added for consistency with the
        companion ``trainBatch`` helper, which returns its cost).

    NOTE(review): depends on module-level globals (``train_iter``, ``image``,
    ``ori_label``, ``MODEL``, ``criterion``, ``loss``, ``optimizer``) —
    confirm they are set up before the training loop runs.
    """
    data = train_iter.next()
    cpu_images = data[0]
    cpu_labels = data[1]

    # Copy the CPU batch into the persistent model input/label tensors.
    utils.loadData(image, cpu_images)
    utils.loadData(ori_label, cpu_labels)

    preds = MODEL(image)
    cost = criterion(preds, ori_label)

    loss.add(cost)

    MODEL.zero_grad()
    cost.backward()
    optimizer.step()

    return cost
Example #3
0
File: main.py  Project: gao-ye/CodeShare
def Load_train_data(args):
	"""Build the training DataLoader from up to three LMDB roots.

	Always loads ``args.train_1``; if ``args.train_2`` / ``args.train_3``
	are given, concatenates them into one dataset.

	Args:
		args: parsed CLI options (alphabet, train_1..train_3, imgW, imgH,
			batchSize, workers).

	Returns:
		A ``torch.utils.data.DataLoader`` over the combined dataset.

	Fixes: the original bodies of the ``if args.train_2`` / ``train_3``
	branches were not indented under their ``if`` (a SyntaxError), and
	``!= None`` is replaced with ``is not None``.
	"""
	def _lmdb(root):
		# Construct one LMDB dataset with the standard resize/normalize transform.
		return dataset.lmdbDataset(args.alphabet, root=root,
			transform=dataset.resizeNormalize((args.imgW, args.imgH)))

	train_dataset = _lmdb(args.train_1)
	assert train_dataset

	# Optional extra roots are concatenated onto the base dataset.
	for extra_root in (args.train_2, args.train_3):
		if extra_root is not None:
			extra_dataset = _lmdb(extra_root)
			assert extra_dataset
			train_dataset = torch.utils.data.ConcatDataset(
				[train_dataset, extra_dataset])

	# DataLoader wraps the dataset and batches samples into tensors of
	# size batchSize (no shuffling; a sampler was previously considered
	# here but is not used).
	train_loader = torch.utils.data.DataLoader(
		train_dataset, batch_size=args.batchSize,
		shuffle=False,
		num_workers=int(args.workers))

	return train_loader

def Load_test_data(dataset_name):
	"""Build an evaluation DataLoader for one LMDB test set.

	Args:
		dataset_name: filesystem root of the LMDB test dataset.

	Returns:
		A ``torch.utils.data.DataLoader`` over the test dataset.

	Fixes: the original bound the local name ``dataset``, shadowing the
	``dataset`` module it was calling into (it only worked because the
	right-hand side is evaluated before the assignment binds).

	NOTE(review): reads the module-level ``args`` rather than taking it as
	a parameter, and uses ``args.alphabet1`` (vs ``alphabet`` in
	``Load_train_data``) — confirm both are intentional.
	"""
	test_dataset = dataset.lmdbDataset(args.alphabet1, test=True, root=dataset_name,
		transform=dataset.resizeNormalize((args.imgW, args.imgH)))

	test_loader = torch.utils.data.DataLoader(
			test_dataset, shuffle=False, batch_size=args.batchSize,
			num_workers=int(args.workers))

	return test_loader


def set_random_seed(random_seed):
	"""Seed every RNG source in use (stdlib, NumPy, PyTorch) for reproducibility."""
	for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
		seed_fn(random_seed)


def train_batch():
    """Run one training step with one-hot-encoded labels.

    Fetches a batch, converts the labels to one-hot form via
    ``utils.label_convert``, runs the model, and performs one optimizer
    step.

    Returns:
        The loss tensor for this step (added for consistency with the
        other training helpers in this file).

    Fixes: the original body mixed tab and space indentation (a TabError
    — the function could not parse); indentation is normalized to 4
    spaces and Korean comments are translated.

    NOTE(review): depends on module-level globals (``train_iter``,
    ``image``, ``ori_label``, ``nclass``, ``MODEL``, ``criterion``,
    ``loss_avg``, ``optimizer``) — confirm they are initialized first.
    """
    data = train_iter.next()
    cpu_images = data[0]
    cpu_labels = data[1]

    utils.loadData(image, cpu_images)
    utils.loadData(ori_label, cpu_labels)

    # Perform one-hot encoding of the labels over nclass classes.
    label = utils.label_convert(ori_label, nclass)
    print("label size", label.shape)

    preds = MODEL(image)
    cost_pred = criterion(preds, label)
    cost = cost_pred

    loss_avg.add(cost)

    MODEL.zero_grad()
    cost.backward()
    optimizer.step()

    return cost