# standard imports; project-local modules (gaze_gen, FeatureExtractor,
# SpatialAttentionModel, Logger, predict, train, validate, adjust_contrast,
# imsave, etc.) are assumed to be imported elsewhere in this repo
import os

import cv2
import torch
import torch.nn as nn


def main():
    # define parameters
    num_class = 6
    batch_size = 1
    time_step = 32
    cnn_feat_size = 256     # AlexNet
    gaze_size = 3
    gaze_lstm_hidden_size = 64
    gaze_lstm_projected_size = 128
    # dataset_path = '../data/gaze_dataset'
    dataset_path = '../../gaze-net/gaze_dataset'
    img_size = (224, 224)
    time_skip = 2

    # define model
    arch = 'alexnet'
    extractor_model = FeatureExtractor(arch=arch)
    extractor_model.features = torch.nn.DataParallel(extractor_model.features)
    extractor_model.cuda()      # comment out this line if using cpu
    extractor_model.eval()

    model = SpatialAttentionModel(num_class, cnn_feat_size, gaze_size,
                                  gaze_lstm_hidden_size, gaze_lstm_projected_size)
    model.cuda()

    # load model from checkpoint
    model = load_checkpoint(model)

    # define data generators
    trainGenerator = gaze_gen.GazeDataGenerator(validation_split=0.2)
    train_data = trainGenerator.flow_from_directory(dataset_path, subset='training', crop=False,
                                                    batch_size=batch_size, target_size=img_size,
                                                    class_mode='sequence_pytorch', time_skip=time_skip)
    # small dataset, error using validation split
    val_data = trainGenerator.flow_from_directory(dataset_path, subset='validation', crop=False,
                                                  batch_size=batch_size, target_size=img_size,
                                                  class_mode='sequence_pytorch', time_skip=time_skip)

    # start prediction
    for i in range(10):
        print("start a new interaction")
        # img_seq: (ts, 224, 224, 3), gaze_seq: (ts, 3), output: (ts, 6)
        # [img_seq, gaze_seq], target = next(val_data)
        [img_seq, gaze_seq], target = next(train_data)
        restart = True
        predict(img_seq, gaze_seq, extractor_model, model, restart=restart)
        print(target)

        # visualize each frame with its gaze point drawn on top
        for j in range(img_seq.shape[0]):
            # predict(img_seq[j], gaze_seq[j], None, model, restart=restart)
            # print(target[j])
            # restart = False
            img = img_seq[j, :, :, :]
            gazes = gaze_seq
            cv2.circle(img, (int(gazes[j, 1]), int(gazes[j, 2])), 10, (255, 0, 0), -1)
            cv2.imshow('ImageWindow', img)
            cv2.waitKey(33)
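
# `load_checkpoint` is called above but not defined in this file. A minimal
# sketch, assuming the checkpoint layout produced by the `save_checkpoint`
# helper in the training script below; the filename and stored keys are
# assumptions and may differ from the project's real helper.
def load_checkpoint(model, filename='model_best.pth.tar'):
    # restore the model weights from a saved training checkpoint
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['state_dict'])
    print("loaded checkpoint '{}' (epoch {})".format(filename, checkpoint['epoch']))
    return model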
def main():
    # define parameters
    TRAIN = True
    num_class = 6
    batch_size = 1
    # time_step = 32
    epochs = 50
    cnn_feat_size = 256     # AlexNet
    gaze_size = 3
    gaze_lstm_hidden_size = 64
    gaze_lstm_projected_size = 128
    learning_rate = 0.0001
    momentum = 0.9
    weight_decay = 1e-4
    eval_freq = 1   # epoch
    print_freq = 1  # iteration
    # dataset_path = '../data/gaze_dataset'
    dataset_path = '../../gaze-net/gaze_dataset'
    img_size = (224, 224)

    log_path = '../log'
    logger = Logger(log_path, 'spatial')

    # define model
    arch = 'alexnet'
    extractor_model = FeatureExtractor(arch=arch)
    extractor_model.features = torch.nn.DataParallel(extractor_model.features)
    extractor_model.cuda()      # comment out this line if using cpu
    extractor_model.eval()

    model = SpatialAttentionModel(num_class, cnn_feat_size, gaze_size,
                                  gaze_lstm_hidden_size, gaze_lstm_projected_size)
    model.cuda()

    # define loss and optimizer
    # criterion = nn.CrossEntropyLoss()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), learning_rate,
                                momentum=momentum, weight_decay=weight_decay)

    # define generator
    trainGenerator = gaze_gen.GazeDataGenerator(validation_split=0.2)
    train_data = trainGenerator.flow_from_directory(dataset_path, subset='training', crop=False,
                                                    batch_size=batch_size, target_size=img_size,
                                                    class_mode='sequence_pytorch')
    # small dataset, error using validation split
    val_data = trainGenerator.flow_from_directory(dataset_path, subset='validation', crop=False,
                                                  batch_size=batch_size, target_size=img_size,
                                                  class_mode='sequence_pytorch')
    # val_data = train_data

    def test(train_data):
        # sanity check: save one frame before and after contrast adjustment
        [img_seq, gaze_seq], target = next(train_data)
        img = img_seq[100, :, :, :]
        img_gamma = adjust_contrast(img)
        imsave('contrast.jpg', img_gamma)
        imsave('original.jpg', img)
    # test(train_data)

    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    # img_seq: (ts, 224, 224, 3), gaze_seq: (ts, 3), output: (ts, 6)
    # [img_seq, gaze_seq], output = next(train_data)
    # print("gaze data shape")
    # print(img_seq.shape)
    # print(gaze_seq.shape)
    # print(output.shape)

    # start training
    para = {'bs': batch_size, 'img_size': img_size, 'num_class': num_class,
            'print_freq': print_freq}
    if TRAIN:
        print("get into training mode")
        best_acc = 0
        for epoch in range(epochs):
            adjust_learning_rate(optimizer, epoch, learning_rate)
            print('Epoch: {}'.format(epoch))
            # train for one epoch
            train(train_data, extractor_model, model, criterion, optimizer, epoch, logger, para)
            # evaluate on validation set
            if epoch % eval_freq == 0 or epoch == epochs - 1:
                acc = validate(val_data, extractor_model, model, criterion, epoch, logger, para, False)
                is_best = acc > best_acc
                best_acc = max(acc, best_acc)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': arch,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
    else:
        model = load_checkpoint(model)
        print("get into testing and visualization mode")
        print("visualization for training data")
        vis_data_path = '../vis/train/'
        if not os.path.exists(vis_data_path):
            os.makedirs(vis_data_path)
        acc = validate(train_data, extractor_model, model, criterion, -1,
                       logger, para, False, vis_data_path)
        print("visualization for validation data")
        vis_data_path = '../vis/val/'
        if not os.path.exists(vis_data_path):
            os.makedirs(vis_data_path)
        acc = validate(val_data, extractor_model, model, criterion, -1,
                       logger, para, True, vis_data_path)
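
# `adjust_learning_rate` and `save_checkpoint` are called in the training loop
# above but not defined here. Minimal sketches following the common PyTorch
# ImageNet-example pattern; the decay schedule (10x every 30 epochs) and the
# checkpoint filenames are assumptions, not the project's confirmed settings.
import shutil

def adjust_learning_rate(optimizer, epoch, learning_rate):
    # decay the learning rate by a factor of 10 every 30 epochs (assumed schedule)
    lr = learning_rate * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # persist the latest training state; keep a copy of the best-accuracy model
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')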