train_file = os.path.join(DATA_DIR, 'X_train.hkl')
train_sources = os.path.join(DATA_DIR, 'sources_train.hkl')
val_file = os.path.join(DATA_DIR, 'X_val.hkl')
val_sources = os.path.join(DATA_DIR, 'sources_val.hkl')

# Wrap the preprocessed KITTI arrays in datasets and loaders.
kitti_train = KITTI(train_file, train_sources, nt)
kitti_val = KITTI(val_file, val_sources, nt)
train_loader = DataLoader(kitti_train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(kitti_val, batch_size=batch_size, shuffle=True)

model = PredNet(R_channels, A_channels, output_mode='error')
if torch.cuda.is_available():
    print('Using GPU.')
    model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=lr)


def lr_scheduler(optimizer, epoch):
    # Keep the initial learning rate for the first half of training,
    # then drop it to 1e-4 for the remaining epochs.
    if epoch < num_epochs // 2:
        return optimizer
    else:
        for param_group in optimizer.param_groups:
            param_group['lr'] = 0.0001
        return optimizer


for epoch in range(num_epochs):
    optimizer = lr_scheduler(optimizer, epoch)
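    # --- Illustrative sketch, not part of the original script ---
    # One minimal way the inner loop could look, assuming output_mode='error'
    # makes the forward pass return a tensor of prediction errors that can be
    # averaged directly into a scalar loss; any KITTI-specific preprocessing
    # (e.g. reordering tensor dimensions) is omitted here.
    model.train()
    for inputs in train_loader:
        if torch.cuda.is_available():
            inputs = inputs.cuda()
        errors = model(inputs)      # per-layer / per-timestep prediction errors
        loss = errors.mean()        # collapse the errors into a single scalar
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()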

# Configure a PredNet from parsed command-line arguments and start training.
n_channels = args.n_channels
img_height = args.img_height
img_width = args.img_width

# Layer sizes could alternatively be parsed from the command line:
# stack_sizes = eval(args.stack_sizes)
# R_stack_sizes = eval(args.R_stack_sizes)
# A_filter_sizes = eval(args.A_filter_sizes)
# Ahat_filter_sizes = eval(args.Ahat_filter_sizes)
# R_filter_sizes = eval(args.R_filter_sizes)

# Four-layer PredNet: channel counts per layer, 3x3 filters throughout.
stack_sizes = (n_channels, 48, 96, 192)
R_stack_sizes = stack_sizes
A_filter_sizes = (3, 3, 3)
Ahat_filter_sizes = (3, 3, 3, 3)
R_filter_sizes = (3, 3, 3, 3)

prednet = PredNet(stack_sizes, R_stack_sizes, A_filter_sizes, Ahat_filter_sizes, R_filter_sizes,
                  output_mode='error', data_format=args.data_format, return_sequences=True)
print(prednet)
prednet.cuda()

assert args.mode == 'train'
train(prednet, args)
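
# --- Illustrative sketch, not part of the original script ---
# One way the `args` namespace used above could be built with argparse. The
# argument names mirror the attributes accessed above (mode, n_channels,
# img_height, img_width, data_format); the defaults shown are assumptions.
import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Train a PredNet model')
    parser.add_argument('--mode', default='train')
    parser.add_argument('--n_channels', type=int, default=3)
    parser.add_argument('--img_height', type=int, default=128)
    parser.add_argument('--img_width', type=int, default=160)
    parser.add_argument('--data_format', default='channels_first')
    return parser.parse_args()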