예제 #1
0
def train(train_img_path, pths_path, batch_size, lr, decay, num_workers,
          epoch_iter, interval, pretained):
    """Train an EAST model and save its weights every ``interval`` epochs.

    Args:
        train_img_path: Directory handed to ``custom_dataset``; its entry
            count is also used for progress and average-loss reporting.
        pths_path: Directory that receives the saved ``.pth`` files.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Learning rate for Adam.
        decay: Weight decay for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run.
        interval: A checkpoint is written whenever ``(epoch + 1)`` is a
            multiple of this value.
        pretained: Path of a state dict to load before training. Loading is
            skipped when the value is falsy or the path does not exist.
    """
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers,
                                   drop_last=True)
    # Batch count used for reporting only; derived from the directory listing.
    batches_per_epoch = int(file_num / batch_size)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST()
    # Guard against None/empty paths: os.path.exists(None) raises TypeError.
    if pretained and os.path.exists(pretained):
        # Load onto the CPU first; the model is moved to ``device`` below.
        model.load_state_dict(torch.load(pretained, map_location='cpu'))

    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=decay)

    for epoch in range(epoch_iter):
        model.train()
        # NOTE: the original called optimizer.step() here. From the second
        # epoch on that re-applied the gradients left over from the previous
        # epoch's final batch, so the call was removed.
        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_map = img.to(device), gt_map.to(device)
            east_detect = model(img)
            inside_score_loss, side_vertex_code_loss, side_vertex_coord_loss = criterion(
                gt_map, east_detect)
            loss = inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 10 == 0:
                print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                    epoch + 1, epoch_iter, i + 1, batches_per_epoch, time.time() - start_time, loss.item()))
                print(
                    "inside_score_loss: %f | side_vertex_code_loss: %f | side_vertex_coord_loss: %f"
                    % (inside_score_loss, side_vertex_code_loss,
                       side_vertex_coord_loss))
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / batches_per_epoch,
            time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        if (epoch + 1) % interval == 0:
            # Save the unwrapped module's weights so the file loads without
            # DataParallel.
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            torch.save(
                state_dict,
                os.path.join(
                    pths_path, cfg.train_task_id +
                    '_model_epoch_{}.pth'.format(epoch + 1)))
def train(train_root_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval):
    """Resume an EAST model from a fixed checkpoint and run periodic validation.

    The per-batch training loop is currently disabled (it is kept below as an
    inert string literal), so each epoch only switches the model to train mode
    and, every ``interval`` epochs, calls ``valid`` on the training loader.

    NOTE(review): this snippet appears to end right after ``state_dict`` is
    computed; ``pths_path`` is never used and nothing is written to disk here.
    Confirm against the full original file.

    Args:
        train_root_path: Passed to ``custom_dataset``.
        pths_path: Unused in the visible code.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Learning rate for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run.
        interval: Validation runs whenever ``epoch % interval == 0``.
    """
    trainset = custom_dataset(train_root_path)
    file_num = trainset.__len__()
    train_loader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST(pretrained=False)
    model.load_state_dict(torch.load('/home/chen-ubuntu/Desktop/checks_dataset/pths/model_epoch_stamp_8.pth'))

    # The model is never wrapped in DataParallel in this function. The flag
    # was read below without being defined, which raised NameError.
    data_parallel = False

    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer.zero_grad()
    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[epoch_iter//2], gamma=0.1)

    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0

        loss_plot = []
        bx = []
        # Disabled training loop, preserved verbatim as a no-op string.
        '''
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(
                device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

            epoch_loss += loss.item()
            loss.backward()
            if (i + 1) % 8 == 0:
                optimizer.step()
                optimizer.zero_grad()

            if (i + 1) % 100 == 0:
                print(
                    'Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                        epoch + 1, epoch_iter, i + 1, int(file_num / batch_size), time.time() - start_time,
                        loss.item()))

            if (i + 1) % 30 == 0:
                loss_plot.append(loss.item())
                bx.append(i + epoch * int(file_num / batch_size))
            plt.plot(bx, loss_plot, label='loss_mean', linewidth=1, color='b', marker='o',
                     markerfacecolor='green', markersize=2)
            plt.savefig(os.path.abspath('./labeled2.jpg'))
        
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(epoch_loss / int(file_num / batch_size),
                                                                  time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        
        print('=' * 50)'''
        if epoch % interval == 0:
            validloss, validacc = valid(train_loader, model, criterion, device)
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
예제 #3
0
def main():
	"""Train and evaluate EAST on ICDAR 2015, logging to CSV files.

	Any existing ``config.SAVE_PATH`` directory is deleted and recreated.
	``train`` and ``test`` are called once per epoch with the open
	``train.csv`` / ``test.csv`` handles. Intermediate checkpoints are written
	every ``config.SAVE_INTERVAL`` epochs (except epoch 0) and a final one
	after the last epoch.
	"""
	config = Config()

	# Start every run from an empty output directory.
	if os.path.exists(config.SAVE_PATH):
		shutil.rmtree(config.SAVE_PATH)
	os.makedirs(config.SAVE_PATH, exist_ok=True)

	train_img_path = os.path.abspath('../ICDAR_2015/train_img')
	train_gt_path  = os.path.abspath('../ICDAR_2015/train_gt')
	val_img_path = os.path.abspath('../ICDAR_2015/test_img')
	val_gt_path  = os.path.abspath('../ICDAR_2015/test_gt')

	kwargs = {'num_workers': 2, 'pin_memory': True} if torch.cuda.is_available() else {}

	# Batch size is scaled by the number of devices in the module-level
	# ``device_list``.
	train_dataset = custom_dataset(train_img_path, train_gt_path)
	train_loader = data.DataLoader(train_dataset, batch_size=config.TRAIN_BATCH*len(device_list),
									shuffle=True, drop_last=True, **kwargs)

	val_dataset = custom_dataset(val_img_path, val_gt_path)
	val_loader = data.DataLoader(val_dataset, batch_size=config.TRAIN_BATCH*len(device_list),
									shuffle=True, drop_last=True, **kwargs)

	net = EAST()

	if torch.cuda.is_available():
		net = net.cuda(device=device_list[0])
		net = torch.nn.DataParallel(net, device_ids=device_list)

	optimizer = torch.optim.Adam(net.parameters(), lr=config.BASE_LR, weight_decay=config.WEIGHT_DECAY)

	# Context managers guarantee both log files are flushed and closed even
	# if training or evaluation raises part-way through.
	with open(os.path.join(config.SAVE_PATH, "train.csv"), 'w') as trainF, \
			open(os.path.join(config.SAVE_PATH, "test.csv"), 'w') as testF:
		for epoch in range(config.EPOCHS):
			train(net, epoch, train_loader, optimizer, trainF, config)
			test(net, epoch, val_loader, testF, config)
			if epoch != 0 and epoch % config.SAVE_INTERVAL == 0:
				torch.save({'state_dict': net.state_dict()}, os.path.join(os.getcwd(), config.SAVE_PATH, "laneNet{}.pth.tar".format(epoch)))

	torch.save({'state_dict': net.state_dict()}, os.path.join(os.getcwd(),  config.SAVE_PATH, "finalNet.pth.tar"))
예제 #4
0
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, interval):
    """Train an EAST model and save its weights every ``interval`` epochs.

    The learning rate starts at ``lr`` and is multiplied by 0.1 by a
    ``MultiStepLR`` schedule with a single milestone at ``epoch_iter // 2``.

    Args:
        train_img_path: Image directory passed to ``custom_dataset``; its
            entry count is also used for progress and average-loss reporting.
        train_gt_path: Ground-truth directory passed to ``custom_dataset``.
        pths_path: Directory that receives ``model_epoch_<n>.pth`` files.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Initial learning rate for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run.
        interval: A checkpoint is written whenever ``(epoch + 1)`` is a
            multiple of this value.
    """
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers,
                                   drop_last=True)
    # Batch count used for reporting only; derived from the directory listing.
    batches_per_epoch = int(file_num / batch_size)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST()
    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epoch_iter // 2],
                                         gamma=0.1)

    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(
                device), gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                epoch + 1, epoch_iter, i + 1, batches_per_epoch, time.time() - start_time, loss.item()))

        # Step the schedule after this epoch's optimizer updates. Stepping it
        # at the top of the epoch (as before) makes PyTorch >= 1.1 apply the
        # decay one epoch early and emit a warning.
        scheduler.step()

        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / batches_per_epoch,
            time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)
        if (epoch + 1) % interval == 0:
            # Save the unwrapped module's weights so the file loads without
            # DataParallel.
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            torch.save(
                state_dict,
                os.path.join(pths_path,
                             'model_epoch_{}.pth'.format(epoch + 1)))
예제 #5
0
def train(config):
    """Train EAST with TensorBoard logging and keep the best-hmean weights.

    Each epoch the model is trained on ``config.train_data_path``, the mean
    epoch loss drives a ``ReduceLROnPlateau`` scheduler, and
    ``evaluate_batch`` is called. Weights are saved under ``config.out``
    whenever the returned ``'hmean'`` exceeds the best value seen so far.

    Args:
        config: Object providing ``out``, ``train_data_path``,
            ``train_batch_size``, ``num_workers``, ``lr`` and ``epoch``; it
            is also forwarded to ``evaluate_batch``.
    """
    tb_writer = SummaryWriter(config.out)
    # try/finally so pending TensorBoard events are flushed and the writer is
    # closed even when training is interrupted by an exception.
    try:
        train_dataset = ICDARDataSet(config.train_data_path)
        file_num = train_dataset.get_num_of_data()
        train_loader = data.DataLoader(train_dataset,
                                       batch_size=config.train_batch_size,
                                       shuffle=True,
                                       num_workers=config.num_workers,
                                       drop_last=True)
        criterion = Loss()
        model = EAST()

        data_parallel = False
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
            data_parallel = True

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   mode='min',
                                                   factor=0.1,
                                                   patience=3,
                                                   verbose=True,
                                                   min_lr=1e-5)

        best_hmean = 0.0

        for epoch in range(config.epoch):
            model.train()
            epoch_loss = 0
            epoch_time = time.time()
            # enumerate() hides the loader's length from tqdm, so pass the
            # total explicitly to get a real progress bar.
            for i, (img, gt_score, gt_geo,
                    ignored_map) in tqdm(enumerate(train_loader),
                                         total=len(train_loader),
                                         desc='Training...'):
                img = img.to(device)
                gt_score, gt_geo, ignored_map = gt_score.to(device), gt_geo.to(
                    device), ignored_map.to(device)
                pred_score, pred_geo = model(img)
                total_loss, classify_loss, angle_loss, iou_loss, geo_loss = criterion(
                    gt_score, pred_score, gt_geo, pred_geo, ignored_map)

                # Global step shared by all per-batch scalars; kept identical
                # to the original formula so existing logs stay comparable.
                step = epoch * len(train_dataset) + i
                tb_writer.add_scalar('train/loss', total_loss, step)
                tb_writer.add_scalar('train/classify_loss', classify_loss, step)
                tb_writer.add_scalar('train/angle_loss', angle_loss, step)
                tb_writer.add_scalar('train/iou_loss', iou_loss, step)
                tb_writer.add_scalar('train/geo_loss', geo_loss, step)

                epoch_loss += total_loss.item()
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

            epoch_loss = epoch_loss / int(file_num / config.train_batch_size)
            print('\n {} epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
                epoch, epoch_loss,
                time.time() - epoch_time))
            print(time.asctime(time.localtime(time.time())))
            print('=' * 50)
            # ReduceLROnPlateau is driven by the mean training loss.
            scheduler.step(epoch_loss)

            epoch_end_step = (epoch + 1) * len(train_dataset)
            tb_writer.add_scalar('lr', get_lr(optimizer), epoch_end_step)

            _, eval_result = evaluate_batch(model, config)
            print(eval_result)
            tb_writer.add_scalar('train/hmean', eval_result['hmean'],
                                 epoch_end_step)
            tb_writer.add_scalar('train/precision', eval_result['precision'],
                                 epoch_end_step)
            tb_writer.add_scalar('train/recall', eval_result['recall'],
                                 epoch_end_step)

            if eval_result['hmean'] > best_hmean:
                best_hmean = eval_result['hmean']
                # Save the unwrapped module's weights so the file loads
                # without DataParallel.
                state_dict = model.module.state_dict(
                ) if data_parallel else model.state_dict()
                torch.save(
                    state_dict,
                    os.path.join(config.out,
                                 'model_epoch_{}.pth'.format(epoch + 1)))
    finally:
        tb_writer.close()
예제 #6
0
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, interval):
    """Train an EAST model, logging one summary line per epoch to train.csv.

    Progress is shown with a tqdm bar. After each epoch the ``MultiStepLR``
    schedule (single milestone at ``epoch_iter // 2``, gamma 0.1) is stepped,
    and weights are saved whenever ``(epoch + 1)`` is a multiple of
    ``interval``.

    Args:
        train_img_path: Image directory passed to ``custom_dataset``; its
            entry count is also used to average the epoch loss.
        train_gt_path: Ground-truth directory passed to ``custom_dataset``.
        pths_path: Directory that receives ``model_epoch_<n>.pth`` files.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Initial learning rate for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run.
        interval: Checkpoint period, in epochs.
    """
    # Data pipeline
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
    )

    # Model, wrapped for multi-GPU when more than one device is visible
    model = EAST()
    data_parallel = torch.cuda.device_count() > 1
    if data_parallel:
        model = nn.DataParallel(model)
    model.to(device)

    # Loss, optimizer and learning-rate schedule
    criterion = Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    halfway = epoch_iter // 2
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[halfway],
                                         gamma=0.1)

    for epoch in range(epoch_iter):
        epoch_no = epoch + 1
        model.train()
        epoch_loss = 0
        epoch_time = time.time()
        train_process = tqdm(train_loader)
        for batch in train_process:
            # Timestamp retained from the original; only its disabled
            # per-batch log consumed it.
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = (t.to(device) for t in batch)

            # Forward pass and loss
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)
            epoch_loss += loss.item()

            # Backward pass and weight update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_process.set_description_str("epoch:{}".format(epoch_no))
            batch_loss_text = "batch_loss:{:.4f}".format(loss.item())
            train_process.set_postfix_str(batch_loss_text)

        scheduler.step()

        mean_loss = epoch_loss / int(file_num / batch_size)
        with open('train.csv', 'a') as log_file:
            elapsed = time.time() - epoch_time
            log_file.write(
                'epoch[{}]: epoch_loss is {:.8f}, epoch_time is {:.8f}\n'.format(
                    epoch_no, mean_loss, elapsed))

        if epoch_no % interval != 0:
            continue
        # Save the unwrapped module's weights so the file loads without
        # DataParallel.
        weights_source = model.module if data_parallel else model
        target = os.path.join(pths_path,
                              'model_epoch_{}.pth'.format(epoch_no))
        torch.save(weights_source.state_dict(), target)
예제 #7
0
파일: train.py 프로젝트: alperkesen/EAST-1
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, interval):
    """Train EAST with per-epoch test evaluation and resumable checkpoints.

    Training resumes from ``./pths/east.pth`` when that file exists;
    otherwise the model is initialised from ``./pths/east_vgg16.pth``. After
    every epoch the test set under ``../ICDAR_2015`` is evaluated, and
    checkpoints are written to ``pths_path`` as:

    * ``east.pth`` every ``interval`` epochs of this run,
    * ``east_epoch_<n>.pth`` whenever the cumulative epoch count is a
      multiple of 10,
    * ``east_best_loss.pth`` whenever the mean test loss improves.

    NOTE(review): ``best_acc`` is carried through checkpoints but never
    updated here; confirm whether accuracy tracking lives elsewhere.

    Args:
        train_img_path: Image directory passed to ``custom_dataset``; its
            entry count is also used for progress and average-loss reporting.
        train_gt_path: Ground-truth directory passed to ``custom_dataset``.
        pths_path: Directory that receives the checkpoints.
        batch_size: Mini-batch size for both loaders.
        lr: Learning rate for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run in this invocation.
        interval: Period, in epochs of this run, for writing ``east.pth``.
    """
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers,
                                   drop_last=True)

    test_img_path = os.path.abspath('../ICDAR_2015/test_img')
    test_gt_path = os.path.abspath('../ICDAR_2015/test_gt')

    file_num2 = len(os.listdir(test_img_path))
    testset = custom_dataset(test_img_path, test_gt_path)
    test_loader = data.DataLoader(testset, batch_size=batch_size,
                                  shuffle=True, num_workers=num_workers,
                                  drop_last=True)

    # Batch counts used for reporting/averaging; derived from the listings.
    train_batches = int(file_num / batch_size)
    test_batches = int(file_num2 / batch_size)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST()
    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Checkpoints hold the unwrapped module's weights, so they must also be
    # loaded into the unwrapped module; loading them into the DataParallel
    # wrapper fails on the missing "module." key prefix.
    raw_model = model.module if data_parallel else model

    try:
        print("(Continue) Loading east...")
        checkpoint = torch.load('./pths/east.pth', map_location=device)
    except FileNotFoundError:
        print("(Initialize) Loading east_vgg16...")
        raw_model.load_state_dict(
            torch.load('./pths/east_vgg16.pth', map_location=device))
        epoch_dict = dict()
        test_dict = dict()
        total_epoch = 0
        best_loss = float('inf')
        best_acc = 0
    else:
        raw_model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch_dict = checkpoint['epoch_loss']
        test_dict = checkpoint['test_loss']
        total_epoch = checkpoint['epoch']
        best_loss = checkpoint['best_loss']
        best_acc = checkpoint['best_acc']

    print("Continue from epoch {}".format(total_epoch))
    print("Epoch_dict", epoch_dict)
    print("Test_dict", test_dict)
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[300],
                                         gamma=0.1)

    def snapshot(epoch_no):
        """Build the checkpoint dict for the current training state."""
        return {
            'epoch': epoch_no,
            'model_state_dict': raw_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch_loss': epoch_dict,
            'test_loss': test_dict,
            'best_loss': best_loss,
            'best_acc': best_acc
        }

    for epoch in range(epoch_iter):
        epoch_no = total_epoch + epoch + 1
        model.train()
        epoch_loss = 0
        test_loss = 0
        epoch_time = time.time()
        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(
                device), gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)

            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            print('Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                epoch + 1, epoch_iter, i + 1, train_batches, time.time() - start_time, loss.item()))

        # Step the schedule after the epoch's optimizer updates (required
        # ordering since PyTorch 1.1).
        scheduler.step()

        epoch_dict[epoch_no] = (epoch_loss / train_batches, epoch_loss)
        print('epoch_loss is {:.8f}, epoch_time is {:.8f}, epoch_loss: {}'.
              format(epoch_loss / train_batches,
                     time.time() - epoch_time, epoch_loss))

        # Evaluate in eval mode so layers such as BatchNorm/Dropout do not
        # behave as in training or update running statistics on test data.
        # model.train() at the top of the loop restores training mode.
        model.eval()
        with torch.no_grad():
            for i, (img, gt_score, gt_geo,
                    ignored_map) in enumerate(test_loader):
                # Per-batch timer; previously this reused the stale value
                # from the last training batch.
                start_time = time.time()
                img, gt_score, gt_geo, ignored_map = img.to(
                    device), gt_score.to(device), gt_geo.to(
                        device), ignored_map.to(device)
                pred_score, pred_geo = model(img)
                loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                                 ignored_map)

                test_loss += loss.item()
                print('Epoch (test) is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                    epoch + 1, epoch_iter, i + 1, test_batches, time.time() - start_time, loss.item()))

        mean_test_loss = test_loss / test_batches
        test_dict[epoch_no] = (mean_test_loss, test_loss)
        print(
            'test_loss is {:.8f}, epoch_time is {:.8f}, test_loss: {}'.format(
                mean_test_loss,
                time.time() - epoch_time, test_loss))

        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)

        # Record the new best before any checkpoint is written. Previously
        # best_loss was never updated, so every epoch overwrote the "best"
        # checkpoint.
        improved = mean_test_loss < best_loss
        if improved:
            best_loss = mean_test_loss

        if (epoch + 1) % interval == 0:
            torch.save(snapshot(epoch_no), os.path.join(pths_path, 'east.pth'))

        if epoch_no % 10 == 0:
            torch.save(
                snapshot(epoch_no),
                os.path.join(pths_path,
                             'east_epoch_{}.pth'.format(epoch_no)))

        if improved:
            torch.save(snapshot(epoch_no),
                       os.path.join(pths_path, 'east_best_loss.pth'))
예제 #8
0
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, interval):
    """Fine-tune EAST from a fixed checkpoint using gradient accumulation.

    Gradients are accumulated over three mini-batches before each optimizer
    update. Weights are saved whenever ``epoch % interval == 0``; the file
    name carries ``epoch + 10`` and the mean epoch loss.

    NOTE(review): the ``+ 10`` offset presumably continues the numbering of
    the hard-coded ``east_model_9_...`` checkpoint loaded below; confirm.

    Args:
        train_img_path: Image directory passed to ``custom_dataset``; its
            entry count is also used for progress and average-loss reporting.
        train_gt_path: Ground-truth directory passed to ``custom_dataset``.
        pths_path: Directory that receives the saved weights.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Learning rate for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run.
        interval: Checkpoint period, in epochs.
    """
    # Number of mini-batches whose gradients are summed per optimizer update.
    accumulation_steps = 3

    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers,
                                   drop_last=True)
    # Batch count used for reporting only; derived from the directory listing.
    batches_per_epoch = int(file_num / batch_size)

    criterion = Loss()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = EAST(pretrained=False)
    model.load_state_dict(
        torch.load('/root/last_dataset/east_tmp_pths/east_model_9_0.2783.pth'))
    data_parallel = False

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer.zero_grad()

    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0
        epoch_time = time.time()

        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(
                device), gt_geo.to(device), ignored_map.to(device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)

            epoch_loss += loss.item()
            loss.backward()
            # Update once every ``accumulation_steps`` batches. The original
            # test `(i + 1) % 3` was truthy on the other two batches of each
            # three, i.e. the inverse of the intended accumulation.
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            if (i + 1) % 100 == 0:
                print(
                    'Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'
                    .format(epoch + 1, epoch_iter, i + 1,
                            batches_per_epoch,
                            time.time() - start_time, loss.item()))

        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / batches_per_epoch,
            time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)
        if epoch % interval == 0:
            # Save the unwrapped module's weights so the file loads without
            # DataParallel.
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            torch.save(
                state_dict,
                os.path.join(
                    pths_path, 'east_model_{}_{:.4f}.pth'.format(
                        epoch + 10, epoch_loss / batches_per_epoch)))
예제 #9
0
def train(train_img_path, train_gt_path, pths_path, batch_size, lr,
          num_workers, epoch_iter, interval):
    """Train an EAST model and save its weights every ``interval`` epochs.

    The learning rate starts at ``lr`` and is multiplied by 0.1 by a
    ``MultiStepLR`` schedule with a single milestone at ``epoch_iter // 2``.

    Args:
        train_img_path: Image directory passed to ``custom_dataset``; its
            entry count is also used for progress and average-loss reporting.
        train_gt_path: Ground-truth directory passed to ``custom_dataset``.
        pths_path: Directory that receives ``model_epoch_<n>.pth`` files.
        batch_size: Mini-batch size for the ``DataLoader``.
        lr: Initial learning rate for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run.
        interval: A checkpoint is written whenever ``(epoch + 1)`` is a
            multiple of this value.
    """
    # Load the data
    file_num = len(os.listdir(train_img_path))
    trainset = custom_dataset(train_img_path, train_gt_path)
    train_loader = data.DataLoader(trainset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=num_workers,
                                   drop_last=True)
    # Batch count used for reporting only; derived from the directory listing.
    batches_per_epoch = int(file_num / batch_size)

    # Build the model
    model = EAST()
    data_parallel = False

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Loss
    criterion = Loss()

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Learning-rate schedule; milestones is a list of epoch indices and must
    # be increasing.
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=[epoch_iter // 2],
                                         gamma=.1)

    for epoch in range(epoch_iter):
        model.train()

        epoch_loss = 0
        epoch_time = time.time()

        for i, (img, gt_score, gt_geo, ignored_map) in enumerate(train_loader):
            start_time = time.time()
            print("start_time=%s" % start_time)

            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), \
                gt_geo.to(device), ignored_map.to(device)

            # Forward pass
            pred_score, pred_geo = model(img)
            # Compute the loss
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo,
                             ignored_map)
            epoch_loss += loss.item()

            # Backward pass; gradients must be zeroed first
            optimizer.zero_grad()
            loss.backward()

            # Update the weights
            optimizer.step()

            print(
                'Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'
                .format(epoch + 1, epoch_iter, i + 1,
                        batches_per_epoch,
                        time.time() - start_time, loss.item()))

        # Step the schedule after this epoch's optimizer updates. Stepping it
        # at the top of the epoch (as before) makes PyTorch >= 1.1 apply the
        # decay one epoch early and emit a warning.
        scheduler.step()

        print('epoch_loss is {:.8f}, epoch_time is {:.8f}'.format(
            epoch_loss / batches_per_epoch,
            time.time() - epoch_time))
        print(time.asctime(time.localtime(time.time())))
        print('=' * 50)

        # Save the weights every ``interval`` epochs
        if (epoch + 1) % interval == 0:
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            torch.save(
                state_dict,
                os.path.join(pths_path,
                             'model_epoch_{}.pth'.format(epoch + 1)))
예제 #10
0
        epoch_loss /= n_mini_batches
        epoch_score_loss /= n_mini_batches
        epoch_geometry_loss /= n_mini_batches
        losses.append(epoch_loss)
        score_losses.append(epoch_score_loss)
        geometry_losses.append(epoch_geometry_loss)
        toc = time.time()
        elapsed_time = toc - tic
        message = "Epoch:{}/{}  ScoreLoss:{:.6f}  GeometryLoss:{:.6f}  Loss:{:.6f}  Duration:{}".format(
            e, epochs, epoch_score_loss, epoch_geometry_loss, epoch_loss,
            time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
        print(message)

        if e % save_step == 0:
            torch.save(model.state_dict(), model_file.format(str(e)))
            keep_n = 1
            file_to_delete = model_file.format(str(e - (keep_n * save_step)))
            if os.path.exists(file_to_delete):
                os.remove(file_to_delete)

        if use_slack and e % slack_epoch_step == 0:
            send_message(slack_client, slack_channel, message)

    loss_type = "score_loss"
    plt.figure()
    plt.plot(range(1, epochs + 1), score_losses, marker="o", linestyle="--")
    plt.xticks(range(1, epochs + 1))
    plt.xlabel("epochs")
    plt.ylabel(loss_type)
    plt.savefig(plot_file.format(loss_type))
예제 #11
0
def train(img_path, gt_path, pths_path, batch_size, lr, num_workers, epoch_iter, interval):
    """Train EAST on a 90/10 train/validation split with gradient accumulation.

    Image and ground-truth files are listed in sorted order, shuffled with
    the same permutation, and split so the first nine tenths train the model
    and the remainder validates it. Gradients are accumulated over eight
    mini-batches per optimizer update. Whenever ``epoch % interval == 0`` the
    model is validated, the results are logged through the module-level
    ``writer``, and the weights are saved under ``pths_path``.

    Returns early (after printing a message) when the two directories do not
    contain the same number of entries.

    Args:
        img_path: Directory of images.
        gt_path: Directory of ground-truth files.
        pths_path: Directory that receives the saved weights.
        batch_size: Mini-batch size for the training loader.
        lr: Learning rate for Adam.
        num_workers: Number of ``DataLoader`` worker processes.
        epoch_iter: Number of epochs to run.
        interval: Validation/checkpoint period, in epochs.
    """
    img_files = [os.path.join(img_path, img_file) for img_file in sorted(os.listdir(img_path))]
    gt_files = [os.path.join(gt_path, gt_file) for gt_file in sorted(os.listdir(gt_path))]

    if len(img_files) != len(gt_files):
        print('dataset is wrong!')
        return

    # Shuffle both lists with the same RNG state so that each image stays
    # aligned with its ground-truth file.
    np.random.seed(10)
    state = np.random.get_state()
    np.random.shuffle(img_files)
    np.random.set_state(state)
    np.random.shuffle(gt_files)

    segment = len(img_files)//10
    train_img_files = img_files[:segment*9]
    train_gt_files = gt_files[:segment*9]
    val_img_files = img_files[segment*9:]
    val_gt_files = gt_files[segment*9:]

    print('trainset: ', len(train_img_files))
    print('validset: ', len(val_img_files))

    trainset = custom_dataset(train_img_files, train_gt_files, transform=True)
    validset = custom_dataset(val_img_files, val_gt_files)

    train_loader = data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    valid_loader = data.DataLoader(validset, batch_size=1, shuffle=True, num_workers=num_workers, drop_last=True)

    train_num = len(train_img_files)

    model = EAST(pretrained=False)
    # To resume from earlier weights:
    # model.load_state_dict(torch.load('/home/chen-ubuntu/Desktop/checks_dataset/pths/model_mode1_epoch_24.pth'))

    data_parallel = False
    if torch.cuda.device_count() > 1:
        print("Use", torch.cuda.device_count(), 'gpus')
        data_parallel = True
        model = nn.DataParallel(model)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    optimizer.zero_grad()

    batch_cnt = 0
    for epoch in range(epoch_iter):
        model.train()
        epoch_loss = 0

        for i, (img, gt_score, gt_geo, ignored_map, _) in enumerate(train_loader):
            batch_cnt += 1
            start_time = time.time()
            img, gt_score, gt_geo, ignored_map = img.to(device), gt_score.to(device), gt_geo.to(device), ignored_map.to(
                device)
            pred_score, pred_geo = model(img)
            loss = criterion(gt_score, pred_score, gt_geo, pred_geo, ignored_map)

            epoch_loss += loss.item()
            loss.backward()

            # Apply the accumulated gradients every eighth batch and log at
            # the same cadence.
            if (i + 1) % 8 == 0:
                optimizer.step()
                optimizer.zero_grad()

                print(
                    'Epoch is [{}/{}], mini-batch is [{}/{}], time consumption is {:.8f}, batch_loss is {:.8f}'.format(
                        epoch + 1, epoch_iter, i + 1, int(train_num / batch_size), time.time() - start_time,
                        loss.item()))
                writer.add_scalar('data/train_loss', loss.item(), batch_cnt)

        if epoch % interval == 0:
            # This call was commented out although its results are used just
            # below, which raised NameError at epoch 0; it is restored here.
            validloss, validacc = valid(valid_loader, model, criterion, device)
            writer.add_scalar('data/valid_loss', validloss, batch_cnt)
            writer.add_scalar('data/valid_acc', validacc, batch_cnt)
            # Save the unwrapped module's weights so the file loads without
            # DataParallel.
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(state_dict, os.path.join(pths_path, 'model_epoch_{}_acc_{:.3f}.pth'.format(epoch + 1, validacc)))

        print('=' * 50)